| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850 |
- import os
- import time
- import math
- import copy
- from functools import partial
- from typing import Optional, Callable, Any
- from collections import OrderedDict
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- from timm.models.layers import DropPath, trunc_normal_
- from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
- DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
- # Training is slower with the following options enabled, so they are left off.
- # torch.backends.cudnn.enabled = True
- # torch.backends.cudnn.benchmark = True
- # torch.backends.cudnn.deterministic = True
- try:
- from .csm_triton import cross_scan_fn, cross_merge_fn
- except:
- from csm_triton import cross_scan_fn, cross_merge_fn
- try:
- from .csms6s import selective_scan_fn, selective_scan_flop_jit
- except:
- from csms6s import selective_scan_fn, selective_scan_flop_jit
- # NOTE: the FLOPs counter does not support mamba2 yet.
- try:
- from .mamba2.ssd_minimal import selective_scan_chunk_fn
- except:
- from mamba2.ssd_minimal import selective_scan_chunk_fn
- # =====================================================
- # Linear2d exists because nn.Linear and nn.Conv2d initialize weights differently:
- # it keeps nn.Linear's init but can load weights saved from either an nn.Linear or a 1x1 nn.Conv2d.
class Linear2d(nn.Linear):
    """``nn.Linear`` applied over the channel dim of a channel-first tensor.

    The weight is stored in ``nn.Linear`` layout ``(out_features, in_features)``
    (so it gets ``nn.Linear``'s initialization) and is run as a 1x1
    ``F.conv2d``. Checkpoints saved from an ``nn.Linear`` or from a 1x1
    ``nn.Conv2d`` (weight ``(out, in, 1, 1)``) can both be loaded.
    """

    def forward(self, x: torch.Tensor):
        # x is handed straight to F.conv2d, so channels must be on dim 1
        # (e.g. (B, C_in, H, W)).
        return F.conv2d(x, self.weight[:, :, None, None], self.bias)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        key = prefix + "weight"
        # Only reshape when the key is present: a missing key must reach the
        # parent implementation so it is reported via missing_keys (or raised
        # under strict=True) rather than surfacing as a KeyError here.
        if key in state_dict:
            # reshape (not view) so non-contiguous checkpoint tensors also load.
            state_dict[key] = state_dict[key].reshape(self.weight.shape)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of a channel-first ``(B, C, H, W)`` tensor.

    The input is moved to channel-last, normalized over ``normalized_shape``
    with this module's affine parameters, and moved back to channel-first.
    """

    def forward(self, x: torch.Tensor):
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
class PatchMerging2D(nn.Module):
    """2x spatial downsampling: stack each 2x2 neighbourhood into channels, normalize, project.

    Odd heights/widths are zero-padded on the bottom/right before merging.

    Args:
        dim: number of input channels C.
        out_dim: number of output channels; a negative value means ``2 * dim``.
        norm_layer: callable building the norm applied to the ``4 * dim`` merged
            channels. NOTE(review): with ``channel_first=True`` the merged tensor
            is ``(B, 4C, H/2, W/2)``, so the ``nn.LayerNorm`` default would
            normalize the trailing (W) axis; pass a channel-first norm instead.
        channel_first: if True inputs are ``(B, C, H, W)``, else ``(B, H, W, C)``.
    """

    def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False):
        super().__init__()
        self.dim = dim
        Linear = Linear2d if channel_first else nn.Linear
        self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last
        self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
        self.norm = norm_layer(4 * dim)

    @staticmethod
    def _patch_merging_pad_channel_last(x: torch.Tensor):
        """Merge 2x2 patches of a ``(..., H, W, C)`` tensor into ``(..., ceil(H/2), ceil(W/2), 4C)``."""
        H, W, _ = x.shape[-3:]
        if (W % 2 != 0) or (H % 2 != 0):
            # F.pad pairs run from the last dim backwards: (C_l, C_r, W_l, W_r, H_l, H_r).
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
        x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
        x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
        x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
        return torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C

    @staticmethod
    def _patch_merging_pad_channel_first(x: torch.Tensor):
        """Merge 2x2 patches of a ``(B, C, H, W)`` tensor into ``(B, 4C, ceil(H/2), ceil(W/2))``."""
        H, W = x.shape[-2:]
        if (W % 2 != 0) or (H % 2 != 0):
            # Spatial dims are last here, so the pad tuple is (W_l, W_r, H_l, H_r).
            # Reusing the channel-last tuple would pad H by W % 2 and C by H % 2.
            x = F.pad(x, (0, W % 2, 0, H % 2))
        x0 = x[..., 0::2, 0::2]  # ... H/2 W/2
        x1 = x[..., 1::2, 0::2]  # ... H/2 W/2
        x2 = x[..., 0::2, 1::2]  # ... H/2 W/2
        x3 = x[..., 1::2, 1::2]  # ... H/2 W/2
        return torch.cat([x0, x1, x2, x3], 1)  # B 4*C H/2 W/2

    def forward(self, x):
        x = self._patch_merging_pad(x)
        x = self.norm(x)
        x = self.reduction(x)
        return x
class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` with a fixed dim order, usable inside ``nn.Sequential``."""

    def __init__(self, *args):
        super().__init__()
        # Target dim order, forwarded unchanged to Tensor.permute.
        self.args = args

    def forward(self, x: torch.Tensor):
        order = self.args
        return x.permute(*order)
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> act -> dropout -> fc2 -> dropout.

    Args:
        in_features: input width.
        hidden_features: hidden width (defaults to ``in_features``).
        out_features: output width (defaults to ``in_features``).
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, applied after the activation and after fc2.
        channels_first: use ``Linear2d`` (channels on dim 1) instead of ``nn.Linear``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
        super().__init__()
        hidden = hidden_features or in_features
        out_dim = out_features or in_features
        linear_cls = Linear2d if channels_first else nn.Linear
        # Construction order is kept fixed so parameter init / state_dict order is stable.
        self.fc1 = linear_cls(in_features, hidden)
        self.act = act_layer()
        self.fc2 = linear_cls(hidden, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
class gMlp(nn.Module):
    """Gated MLP: fc1 produces two halves, one gates the other, fc2 projects back.

    Computes ``drop(fc2(a * act(b)))`` where ``a, b`` are the two halves of ``fc1(x)``
    split along the channel dim (dim 1 if channel-first, else the last dim).

    Args:
        in_features: input width.
        hidden_features: width of each half (defaults to ``in_features``).
        out_features: output width (defaults to ``in_features``).
        act_layer: activation class applied to the gate half.
        drop: dropout probability applied to the output.
        channels_first: use ``Linear2d`` (channels on dim 1) instead of ``nn.Linear``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
        super().__init__()
        self.channel_first = channels_first
        hidden = hidden_features or in_features
        out_dim = out_features or in_features
        linear_cls = Linear2d if channels_first else nn.Linear
        self.fc1 = linear_cls(in_features, 2 * hidden)
        self.act = act_layer()
        self.fc2 = linear_cls(hidden, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor):
        channel_dim = 1 if self.channel_first else -1
        value, gate = self.fc1(x).chunk(2, dim=channel_dim)
        return self.drop(self.fc2(value * self.act(gate)))
class SoftmaxSpatial(nn.Softmax):
    """Softmax over all spatial positions of a 4-D feature map, per batch and channel.

    ``dim=-1`` expects ``(B, C, H, W)`` and normalizes over ``H*W``;
    ``dim=1`` expects ``(B, H, W, C)`` and normalizes over ``H*W``.
    Any other ``dim`` raises ``NotImplementedError``.
    """

    def forward(self, x: torch.Tensor):
        if self.dim == 1:
            b, h, w, c = x.shape
            flat = x.view(b, -1, c)
        elif self.dim == -1:
            b, c, h, w = x.shape
            flat = x.view(b, c, -1)
        else:
            raise NotImplementedError
        # Restore the original 4-D layout after normalizing the flattened spatial axis.
        return super().forward(flat).view(x.shape)
- # =====================================================
class mamba_init:
    """Namespace of Mamba-style initializers for the SSM parameters (dt projection, A_log, D)."""

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
        """Build an ``nn.Linear(dt_rank, d_inner)`` used as the dt projection.

        The weight is set to the constant ``dt_rank ** -0.5 * dt_scale``
        (``dt_init="constant"``) or drawn uniformly in +/- that value
        (``dt_init="random"``); other values raise ``NotImplementedError``.
        The bias is the inverse softplus of a dt sampled log-uniformly in
        ``[dt_min, dt_max]`` and clamped below at ``dt_init_floor``, so that
        ``softplus(bias)`` recovers that dt.
        """
        proj = nn.Linear(dt_rank, d_inner, bias=True)

        bound = dt_rank**-0.5 * dt_scale
        if dt_init == "random":
            nn.init.uniform_(proj.weight, -bound, bound)
        elif dt_init == "constant":
            nn.init.constant_(proj.weight, bound)
        else:
            raise NotImplementedError

        # Log-uniform sample of dt in [dt_min, dt_max].
        log_lo = math.log(dt_min)
        log_span = math.log(dt_max) - log_lo
        dt = torch.exp(torch.rand(d_inner) * log_span + log_lo).clamp(min=dt_init_floor)
        # softplus^-1(dt) = dt + log(1 - exp(-dt)); see pytorch/pytorch#72759.
        inv_softplus_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            proj.bias.copy_(inv_softplus_dt)
        return proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        """Return ``log([1..d_state])`` tiled to ``(d_inner, d_state)`` as an fp32 Parameter (S4D-real init).

        With ``copies > 0`` the table is repeated to ``(copies, d_inner, d_state)``
        and, if ``merge``, flattened to ``(copies * d_inner, d_state)``.
        The Parameter is tagged ``_no_weight_decay = True``.
        """
        ramp = torch.arange(1, d_state + 1, dtype=torch.float32, device=device)
        log_a = torch.log(ramp.unsqueeze(0).repeat(d_inner, 1))
        if copies > 0:
            log_a = log_a.unsqueeze(0).repeat(copies, 1, 1)
            if merge:
                log_a = log_a.flatten(0, 1)
        param = nn.Parameter(log_a)
        param._no_weight_decay = True
        return param

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        """Return the all-ones skip parameter D of shape ``(d_inner,)``.

        With ``copies > 0`` it becomes ``(copies, d_inner)``, flattened to
        ``(copies * d_inner,)`` if ``merge``. Tagged ``_no_weight_decay = True``.
        """
        skip = torch.ones(d_inner, device=device)
        if copies > 0:
            skip = skip.unsqueeze(0).repeat(copies, 1)
            if merge:
                skip = skip.flatten(0, 1)
        param = nn.Parameter(skip)
        param._no_weight_decay = True
        return param

    @classmethod
    def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
        """Create the stacked SSM parameters for ``k_group`` scan directions.

        Returns ``(A_logs, Ds, dt_projs_weight, dt_projs_bias)`` with shapes
        ``(k_group * d_inner, d_state)``, ``(k_group * d_inner,)``,
        ``(k_group, d_inner, dt_rank)`` and ``(k_group, d_inner)``.
        """
        # All dt projections are built first, keeping the RNG draw order stable.
        projs = [
            cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
            for _ in range(k_group)
        ]
        stacked_weight = nn.Parameter(torch.stack([p.weight for p in projs], dim=0))
        stacked_bias = nn.Parameter(torch.stack([p.bias for p in projs], dim=0))

        A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)
        Ds = cls.D_init(d_inner, copies=k_group, merge=True)
        return A_logs, Ds, stacked_weight, stacked_bias
- # support: v0, v0seq
class SS2Dv0:
    """Mixin providing the original SS2D variant (forward types "v0" / "v0seq").

    This class is not an ``nn.Module`` itself. ``__initv0__`` calls
    ``super().__init__()`` and then registers submodules and Parameters, so it is
    presumably combined with ``nn.Module`` in a concrete class — TODO confirm.
    """

    def __initv0__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        # ======================
        dropout=0.0,
        # ======================
        seq=False,
        force_fp32=True,
        **kwargs,
    ):
        """Build the v0 layers.

        Args:
            d_model: input/output channel count.
            d_state: SSM state size N.
            ssm_ratio: inner width multiplier (``d_inner = int(ssm_ratio * d_model)``).
            dt_rank: rank of the dt projection; ``"auto"`` means ``ceil(d_model / 16)``.
            dropout: dropout after the output projection (Identity when 0).
            seq: run the four directional scans one at a time instead of batched.
            force_fp32: cast the scan inputs to fp32 before the selective scan.
            **kwargs: ignored, except that ``channel_first`` must be falsy if given.
        """
        if "channel_first" in kwargs:
            # This variant only supports channel-last (B, H, W, C) input.
            assert not kwargs["channel_first"]
        act_layer = nn.SiLU
        dt_min = 0.001
        dt_max = 0.1
        dt_init = "random"
        dt_scale = 1.0
        dt_init_floor = 1e-4
        bias = False
        conv_bias = True
        d_conv = 3
        k_group = 4
        factory_kwargs = {"device": None, "dtype": None}

        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

        self.forward = self.forwardv0
        if seq or not force_fp32:
            # Bind both options in one partial: assigning them in two separate
            # partials let force_fp32=False silently discard seq=True.
            self.forward = partial(self.forwardv0, seq=seq, force_fp32=force_fp32)

        # in proj ============================
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
        self.act: nn.Module = act_layer()
        self.conv2d = nn.Conv2d(
            in_channels=d_inner,
            out_channels=d_inner,
            groups=d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )

        # x proj ============================
        # One Linear per scan direction, stacked into a single (K, R + 2N, d_inner) Parameter.
        self.x_proj = [
            nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
            for _ in range(k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj

        # dt proj, A, D ============================
        self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
            d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=k_group,
        )

        # out proj =======================================
        self.out_norm = nn.LayerNorm(d_inner)
        self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

    def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
        """Apply the v0 SS2D block to a channel-last ``(B, H, W, d_model)`` tensor.

        The input is projected, split into a value branch and a SiLU gate ``z``,
        and the value branch goes through a depthwise conv. It is then scanned in
        four directions: row-major, column-major, and both reversed. The four
        outputs are merged back to row-major order, summed, layer-normed, gated by
        ``z`` and projected back to ``(B, H, W, d_model)``.

        Args:
            seq: scan each direction in its own selective-scan call.
            force_fp32: cast scan inputs to fp32 before scanning.
        """
        x = self.in_proj(x)
        x, z = x.chunk(2, dim=-1)  # (b, h, w, d)
        z = self.act(z)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.conv2d(x)  # (b, d, h, w)
        x = self.act(x)
        selective_scan = partial(selective_scan_fn, backend="mamba")

        B, D, H, W = x.shape
        D, N = self.A_logs.shape
        K, D, R = self.dt_projs_weight.shape
        L = H * W

        # Directions 0/1: row-major and column-major flattening; 2/3: their reverses.
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
        if hasattr(self, "x_proj_bias"):
            x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

        xs = xs.view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.contiguous()  # (b, k, d_state, l)
        Cs = Cs.contiguous()  # (b, k, d_state, l)

        As = -self.A_logs.float().exp()  # (k * d, d_state)
        Ds = self.Ds.float()  # (k * d)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        if force_fp32:
            xs, dts, Bs, Cs = (t.to(torch.float32) for t in (xs, dts, Bs, Cs))

        if seq:
            out_y = []
            for i in range(4):
                yi = selective_scan(
                    xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
                    As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
                    delta_bias=dt_projs_bias.view(K, -1)[i],
                    delta_softplus=True,
                ).view(B, -1, L)
                out_y.append(yi)
            out_y = torch.stack(out_y, dim=1)
        else:
            out_y = selective_scan(
                xs, dts,
                As, Bs, Cs, Ds,
                delta_bias=dt_projs_bias,
                delta_softplus=True,
            ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # Undo the reversal of directions 2/3, and map column-major results back to row-major.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y

        y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
        y = self.out_norm(y).view(B, H, W, -1)
        y = y * z
        out = self.dropout(self.out_proj(y))
        return out
- # support: v01-v05; v051d,v052d,v052dc;
- # postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32;
- # history support: v2,v3;v31d,v32d,v32dc;
- class SS2Dv2:
- def __initv2__(
- self,
- # basic dims ===========
- d_model=96,
- d_state=16,
- ssm_ratio=2.0,
- dt_rank="auto",
- act_layer=nn.SiLU,
- # dwconv ===============
- d_conv=3, # < 2 means no conv
- conv_bias=True,
- # ======================
- dropout=0.0,
- bias=False,
- # dt init ==============
- dt_min=0.001,
- dt_max=0.1,
- dt_init="random",
- dt_scale=1.0,
- dt_init_floor=1e-4,
- initialize="v0",
- # ======================
- forward_type="v2",
- channel_first=False,
- # ======================
- **kwargs,
- ):
- factory_kwargs = {"device": None, "dtype": None}
- super().__init__()
- self.k_group = 4
- self.d_model = int(d_model)
- self.d_state = int(d_state)
- self.d_inner = int(ssm_ratio * d_model)
- self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
- self.channel_first = channel_first
- self.with_dconv = d_conv > 1
- Linear = Linear2d if channel_first else nn.Linear
- self.forward = self.forwardv2
- # tags for forward_type ==============================
- checkpostfix = self.checkpostfix
- self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
- self.oact, forward_type = checkpostfix("_oact", forward_type)
- self.disable_z, forward_type = checkpostfix("_noz", forward_type)
- self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
- self.out_norm, forward_type = self.get_outnorm(forward_type, self.d_inner, channel_first)
- # forward_type debug =======================================
- FORWARD_TYPES = dict(
- v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", scan_force_torch=True),
- v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"),
- v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"),
- v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d"
- v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), # selective_scan_backend="oflex", scan_mode="cross2d"
- # ===============================
- v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"),
- v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"),
- v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"),
- v052d3=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode=3), # debug
- # ===============================
- v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"),
- v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"),
- )
- self.forward_core = FORWARD_TYPES.get(forward_type, None)
- # in proj =======================================
- d_proj = self.d_inner if self.disable_z else (self.d_inner * 2)
- self.in_proj = Linear(self.d_model, d_proj, bias=bias)
- self.act: nn.Module = act_layer()
-
- # conv =======================================
- if self.with_dconv:
- self.conv2d = nn.Conv2d(
- in_channels=self.d_inner,
- out_channels=self.d_inner,
- groups=self.d_inner,
- bias=conv_bias,
- kernel_size=d_conv,
- padding=(d_conv - 1) // 2,
- **factory_kwargs,
- )
- # x proj ============================
- self.x_proj = [
- nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False)
- for _ in range(self.k_group)
- ]
- self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
- del self.x_proj
-
- # out proj =======================================
- self.out_act = nn.GELU() if self.oact else nn.Identity()
- self.out_proj = Linear(self.d_inner, self.d_model, bias=bias)
- self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
- if initialize in ["v0"]:
- self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
- self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
- )
- elif initialize in ["v1"]:
- # simple init dt_projs, A_logs, Ds
- self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
- self.A_logs = nn.Parameter(torch.randn((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
- self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner, self.dt_rank))) # 0.1 is added in 0430
- self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner))) # 0.1 is added in 0430
- elif initialize in ["v2"]:
- # simple init dt_projs, A_logs, Ds
- self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
- self.A_logs = nn.Parameter(torch.zeros((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
- self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner, self.dt_rank)))
- self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner)))
- def forward_corev2(
- self,
- x: torch.Tensor=None,
- # ==============================
- force_fp32=False, # True: input fp32
- # ==============================
- ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input
- no_einsum=False, # replace einsum with linear or conv1d to raise throughput
- # ==============================
- selective_scan_backend = None,
- # ==============================
- scan_mode = "cross2d",
- scan_force_torch = False,
- # ==============================
- **kwargs,
- ):
- assert selective_scan_backend in [None, "oflex", "mamba", "torch"]
- _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode # for debug
- assert isinstance(_scan_mode, int)
- delta_softplus = True
- out_norm = self.out_norm
- channel_first = self.channel_first
- to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
- B, D, H, W = x.shape
- N = self.d_state
- K, D, R = self.k_group, self.d_inner, self.dt_rank
- L = H * W
- def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
- return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend)
-
- if _scan_mode == -1:
- x_proj_bias = getattr(self, "x_proj_bias", None)
- def scan_rowcol(
- x: torch.Tensor,
- proj_weight: torch.Tensor,
- proj_bias: torch.Tensor,
- dt_weight: torch.Tensor,
- dt_bias: torch.Tensor, # (2*c)
- _As: torch.Tensor, # As = -torch.exp(A_logs.to(torch.float))[:2,] # (2*c, d_state)
- _Ds: torch.Tensor,
- width = True,
- ):
- # x: (B, D, H, W)
- # proj_weight: (2 * D, (R+N+N))
- XB, XD, XH, XW = x.shape
- if width:
- _B, _D, _L = XB * XH, XD, XW
- xs = x.permute(0, 2, 1, 3).contiguous()
- else:
- _B, _D, _L = XB * XW, XD, XH
- xs = x.permute(0, 3, 1, 2).contiguous()
- xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2) # (B, H, 2, D, W)
- if no_einsum:
- x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2)
- dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2)
- dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2)
- else:
- x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight)
- if x_proj_bias is not None:
- x_dbl = x_dbl + x_proj_bias.view(1, 2, -1, 1)
- dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
- dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_weight)
- xs = xs.view(_B, -1, _L)
- dts = dts.contiguous().view(_B, -1, _L)
- As = _As.view(-1, N).to(torch.float)
- Bs = Bs.contiguous().view(_B, 2, N, _L)
- Cs = Cs.contiguous().view(_B, 2, N, _L)
- Ds = _Ds.view(-1)
- delta_bias = dt_bias.view(-1).to(torch.float)
- if force_fp32:
- xs = xs.to(torch.float)
- dts = dts.to(xs.dtype)
- Bs = Bs.to(xs.dtype)
- Cs = Cs.to(xs.dtype)
- ys: torch.Tensor = selective_scan(
- xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
- ).view(_B, 2, -1, _L)
- return ys
-
- As = -self.A_logs.to(torch.float).exp().view(4, -1, N)
- x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
- y_row = scan_rowcol(
- x,
- proj_weight = self.x_proj_weight.view(4, -1, D)[:2].contiguous(),
- proj_bias = (x_proj_bias.view(4, -1)[:2].contiguous() if x_proj_bias is not None else None),
- dt_weight = self.dt_projs_weight.view(4, D, -1)[:2].contiguous(),
- dt_bias = (self.dt_projs_bias.view(4, -1)[:2].contiguous() if self.dt_projs_bias is not None else None),
- _As = As[:2].contiguous().view(-1, N),
- _Ds = self.Ds.view(4, -1)[:2].contiguous().view(-1),
- width=True,
- ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3) # (B,C,H,W)
- y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
- y_col = scan_rowcol(
- y_row,
- proj_weight = self.x_proj_weight.view(4, -1, D)[2:].contiguous().to(y_row.dtype),
- proj_bias = (x_proj_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if x_proj_bias is not None else None),
- dt_weight = self.dt_projs_weight.view(4, D, -1)[2:].contiguous().to(y_row.dtype),
- dt_bias = (self.dt_projs_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if self.dt_projs_bias is not None else None),
- _As = As[2:].contiguous().view(-1, N),
- _Ds = self.Ds.view(4, -1)[2:].contiguous().view(-1),
- width=False,
- ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1)
- y = y_col
- else:
- x_proj_bias = getattr(self, "x_proj_bias", None)
- xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
- if no_einsum:
- x_dbl = F.conv1d(xs.view(B, -1, L), self.x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
- dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
- if hasattr(self, "dt_projs_weight"):
- dts = F.conv1d(dts.contiguous().view(B, -1, L), self.dt_projs_weight.view(K * D, -1, 1), groups=K)
- else:
- x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
- if x_proj_bias is not None:
- x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
- dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
- if hasattr(self, "dt_projs_weight"):
- dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
- xs = xs.view(B, -1, L)
- dts = dts.contiguous().view(B, -1, L)
- As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
- Ds = self.Ds.to(torch.float) # (K * c)
- Bs = Bs.contiguous().view(B, K, N, L)
- Cs = Cs.contiguous().view(B, K, N, L)
- delta_bias = self.dt_projs_bias.view(-1).to(torch.float)
- if force_fp32:
- xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
- ys: torch.Tensor = selective_scan(
- xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
- ).view(B, K, -1, H, W)
-
- y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
- if getattr(self, "__DEBUG__", False):
- setattr(self, "__data__", dict(
- A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
- us=xs, dts=dts, delta_bias=delta_bias,
- ys=ys, y=y, H=H, W=W,
- ))
- y = y.view(B, -1, H, W)
- if not channel_first:
- y = y.view(B, -1, H * W).transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) # (B, L, C)
- y = out_norm(y)
- return y.to(x.dtype)
- def forwardv2(self, x: torch.Tensor, **kwargs):
- x = self.in_proj(x)
- if not self.disable_z:
- x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
- if not self.disable_z_act:
- z = self.act(z)
- if not self.channel_first:
- x = x.permute(0, 3, 1, 2).contiguous()
- if self.with_dconv:
- x = self.conv2d(x) # (b, d, h, w)
- x = self.act(x)
- y = self.forward_core(x)
- y = self.out_act(y)
- if not self.disable_z:
- y = y * z
- out = self.dropout(self.out_proj(y))
- return out
@staticmethod
def get_outnorm(forward_type="", d_inner=192, channel_first=True):
    """Pick the SS2D output-norm module from ``forward_type``'s out-norm postfix.

    The postfixes ``_onnone``, ``_ondwconv3``, ``_oncnorm``, ``_onsoftmax`` and
    ``_onsigmoid`` are peeled off the end of ``forward_type`` in that order. The
    module returned is, by priority: Identity (``_onnone``); norm followed by a
    depthwise 3x3 conv (``_oncnorm``); a depthwise 3x3 conv (``_ondwconv3``);
    ``SoftmaxSpatial`` (``_onsoftmax``); Sigmoid (``_onsigmoid``); otherwise a
    plain norm over ``d_inner`` channels (``LayerNorm2d`` when ``channel_first``,
    else ``nn.LayerNorm``).

    Returns:
        ``(out_norm, forward_type with those postfixes removed)``.
    """
    def peel(tag, text):
        # Compare the trailing len(tag) characters, exactly like checkpostfix.
        if text[-len(tag):] == tag:
            return True, text[:-len(tag)]
        return False, text

    norm_cls = LayerNorm2d if channel_first else nn.LayerNorm
    found = {}
    for tag in ("_onnone", "_ondwconv3", "_oncnorm", "_onsoftmax", "_onsigmoid"):
        found[tag], forward_type = peel(tag, forward_type)

    def depthwise3x3():
        # nn.Conv2d needs channel-first input, so channel-last maps are permuted around it.
        return [
            nn.Identity() if channel_first else Permute(0, 3, 1, 2),
            nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
            nn.Identity() if channel_first else Permute(0, 2, 3, 1),
        ]

    if found["_onnone"]:
        return nn.Identity(), forward_type
    if found["_oncnorm"]:
        return nn.Sequential(norm_cls(d_inner), *depthwise3x3()), forward_type
    if found["_ondwconv3"]:
        return nn.Sequential(*depthwise3x3()), forward_type
    if found["_onsoftmax"]:
        return SoftmaxSpatial(dim=(-1 if channel_first else 1)), forward_type
    if found["_onsigmoid"]:
        return nn.Sigmoid(), forward_type
    return norm_cls(d_inner), forward_type
@staticmethod
def checkpostfix(tag, value):
    """Check whether ``value`` ends with ``tag``; if so, strip it.

    Returns ``(matched, remainder)``. When there is no match, ``remainder`` is
    ``value`` unchanged.
    """
    tail_len = len(tag)
    if value[-tail_len:] != tag:
        return False, value
    return True, value[:-tail_len]
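# SS2Dv3 modes: xv1a, xv2a, xv3a.
# Optional postfixes, written left to right after the mode: _cpos; _ocov or _ocov2; _ca or _ca1;
# _act; _mul; one out-norm postfix (_onsigmoid, _onsoftmax, _oncnorm, _ondwconv3 or _onnone).
# The conv postfixes (_cpos, _ocov*, _ca*) are only parsed when d_conv > 1.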
class SS2Dv3:
    """Mixin implementing the experimental ``xv*`` SS2D variants (initializer + forward).

    Meant to be mixed into an ``nn.Module`` (see ``SS2D``): ``__initxv__`` registers
    parameters/submodules on ``self`` and binds ``self.forward`` to ``forwardxv``. A single
    ``in_proj`` produces the scan input ``us``, the step-size features ``dts`` and the
    four-direction ``Bs``/``Cs`` in one go.
    """

    def __initxv__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v0",
        # ======================
        forward_type="v2",
        channel_first=False,
        # ======================
        **kwargs,
    ):
        """Initialize an ``xv1a`` / ``xv2a`` / ``xv3a`` block.

        ``forward_type`` is a 4-character mode followed by optional postfixes. They are
        stripped from the end in this order: the out-norm postfixes handled by
        ``SS2Dv2.get_outnorm``, ``_mul``, ``_act``, then (only when ``d_conv > 1``)
        ``_ca``, ``_ca1``, ``_ocov2``, ``_ocov``, ``_cpos``. The mode fixes the width of
        the ``dts`` slice of ``in_proj``: ``dt_rank`` (xv1a), ``d_inner`` (xv2a) or
        ``4 * dt_rank`` (xv3a).
        """
        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.channel_first = channel_first
        self.d_state = d_state
        self.dt_rank = dt_rank
        self.d_inner = d_inner
        k_group = 4
        self.with_dconv = d_conv > 1
        Linear = Linear2d if channel_first else nn.Linear
        self.forward = self.forwardxv

        # tags for forward_type ==============================
        checkpostfix = SS2Dv2.checkpostfix
        self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, channel_first)
        self.omul, forward_type = checkpostfix("_mul", forward_type)
        self.oact, forward_type = checkpostfix("_act", forward_type)
        self.f_omul = nn.Identity() if self.omul else None
        self.out_act = nn.GELU() if self.oact else nn.Identity()

        mode = forward_type[:4]
        assert mode in ["xv1a", "xv2a", "xv3a"]
        # `mode` is passed as a keyword and absorbed by forwardxv's **kwargs.
        self.forward = partial(self.forwardxv, mode=mode)
        self.dts_dim = dict(xv1a=self.dt_rank, xv2a=self.d_inner, xv3a=4 * self.dt_rank)[mode]
        # in_proj output channels: [us: d_inner | dts: dts_dim | Bs: 4*d_state | Cs: 4*d_state]
        d_inner_all = d_inner + self.dts_dim + 8 * d_state
        self.in_proj = Linear(d_model, d_inner_all, bias=bias)

        # conv =======================================
        self.cpos = False
        self.iconv = False
        self.oconv = False
        self.oconv2 = False
        if self.with_dconv:
            cact, forward_type = checkpostfix("_ca", forward_type)
            cact1, forward_type = checkpostfix("_ca1", forward_type)
            self.cact = nn.SiLU() if cact else nn.Identity()
            self.cact = nn.GELU() if cact1 else self.cact

            self.oconv2, forward_type = checkpostfix("_ocov2", forward_type)
            self.oconv, forward_type = checkpostfix("_ocov", forward_type)
            self.cpos, forward_type = checkpostfix("_cpos", forward_type)
            # Input-side conv is the default whenever no output-side conv was requested.
            self.iconv = (not self.oconv) and (not self.oconv2)
            if self.iconv:
                # depthwise conv on the block input (d_model channels), before in_proj
                self.conv2d = nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    groups=d_model,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                )
            if self.oconv:
                # depthwise conv on `us`, added to the output in forwardxv
                self.oconv2d = nn.Conv2d(
                    in_channels=d_inner,
                    out_channels=d_inner,
                    groups=d_inner,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                )
            if self.oconv2:
                # depthwise conv over the whole in_proj output
                self.conv2d = nn.Conv2d(
                    in_channels=d_inner_all,
                    out_channels=d_inner_all,
                    groups=d_inner_all,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                )

        # out proj =======================================
        self.out_proj = Linear(d_inner, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        if initialize in ["v0"]:
            self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
                d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
            )
        elif initialize in ["v1"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
            self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
            self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
        elif initialize in ["v2"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
            self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank)))
            self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner)))

        # xv2a feeds the d_inner-wide dts straight into the scan, so no dt projection is used.
        if forward_type.startswith("xv2"):
            del self.dt_projs_weight
            self.dt_projs_weight = None

    def forwardxv(self, x: torch.Tensor, **kwargs):
        """Run the xv* block.

        ``x`` is ``(B, C, H, W)`` when ``self.channel_first`` else ``(B, H, W, C)`` (per the
        shape unpacking below); the output keeps the same layout. Extra kwargs are ignored.
        """
        B, (H, W) = x.shape[0], (x.shape[2:4] if self.channel_first else x.shape[1:3])
        L = H * W
        force_fp32 = False
        delta_softplus = True
        out_norm = self.out_norm
        to_dtype = True
        to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

        def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):
            return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex=True, backend=None)

        if self.iconv:
            x = self.cact(self.conv2d(x))  # (b, d, h, w)
        elif self.cpos:
            x = x + self.conv2d(x)  # (b, d, h, w)
        x = self.in_proj(x)

        if self.oconv2:
            x = self.conv2d(x)  # (b, d, h, w)
        us, dts, Bs, Cs = x.split(
            [self.d_inner, self.dts_dim, 4 * self.d_state, 4 * self.d_state],
            dim=(1 if self.channel_first else -1),
        )
        _us = us  # kept in the block's layout for the _mul / _ocov output paths
        us = cross_scan_fn(us.contiguous(), in_channel_first=self.channel_first, out_channel_first=True).view(B, -1, L)
        Bs = cross_scan_fn(Bs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
        Cs = cross_scan_fn(Cs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
        # Channels must precede length (like `us`): F.conv1d and the selective scan both
        # consume (batch, channels, length). Viewing as (B, L, -1) would scramble the data.
        dts = cross_scan_fn(
            dts.contiguous(), in_channel_first=self.channel_first, out_channel_first=True,
            one_by_one=(self.dts_dim == 4 * self.dt_rank),
        ).view(B, -1, L)
        if self.dts_dim in (self.dt_rank, 4 * self.dt_rank):
            # grouped 1x1 conv: one dt_rank -> d_inner projection per scan direction
            dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)

        As = -self.A_logs.to(torch.float).exp()  # (k * c, d_state)
        Ds = self.Ds.to(torch.float)  # (K * c)
        delta_bias = self.dt_projs_bias.view(-1).to(torch.float)  # (K * c)

        if force_fp32:
            us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs)

        ys: torch.Tensor = selective_scan(
            us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
        ).view(B, 4, -1, H, W)
        # ys is channel-first (B, 4, C, H, W); merge back into this block's own layout so the
        # view below (and the _mul / _ocov paths that mix y with _us) see matching data.
        y: torch.Tensor = cross_merge_fn(ys.contiguous(), in_channel_first=True, out_channel_first=self.channel_first)
        y = y.view(B, -1, H, W) if self.channel_first else y.view(B, H, W, -1)
        y = out_norm(y)

        if getattr(self, "__DEBUG__", False):
            setattr(self, "__data__", dict(
                A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
                us=us, dts=dts, delta_bias=delta_bias,
                ys=ys, y=y,
            ))

        y = (y.to(x.dtype) if to_dtype else y)
        y = self.out_act(y)
        if self.omul:
            y = y * _us
        if self.oconv:
            y = y + self.cact(self.oconv2d(_us))
        out = self.dropout(self.out_proj(y))
        return out
- # mamba2 support ================================
class SS2Dm0:
    """Mixin implementing the Mamba2-style ("m0") SS2D variant.

    Channel-last only: ``in_proj`` is ``nn.Linear``, the depthwise conv is wrapped in
    permutes, and ``forward_corem0`` unpacks its input as ``(B, H, W, C)``.
    """

    def __initm0__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,  # now with mamba2, dstate should be bigger...
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.GELU,
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v2",
        # ======================
        forward_type="m0",
        # ======================
        with_initial_state=False,
        # ======================
        **kwargs,
    ):
        """Initialize the m0 block.

        ``forward_type`` postfixes stripped from the end, in order: ``_no32``, ``_oact``,
        ``_noz`` (no gate branch), ``_nozact`` (gate without activation), then the
        out-norm postfixes of ``SS2Dv2.get_outnorm``. ``dt_rank`` heads share
        ``d_inner // dt_rank`` channels each (``d_inner`` must be divisible by ``dt_rank``).
        ``A_logs`` / ``Ds`` / ``dt_projs_bias`` are only created for ``initialize`` in
        ("v1", "v2").
        """
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        assert d_inner % dt_rank == 0
        # This variant always works on channel-last tensors; record it so callers (and
        # forwardm0) never depend on an attribute that was previously left unset.
        self.channel_first = False
        self.with_dconv = d_conv > 1
        Linear = nn.Linear
        self.forward = self.forwardm0

        # tags for forward_type ==============================
        checkpostfix = SS2Dv2.checkpostfix
        self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
        self.oact, forward_type = checkpostfix("_oact", forward_type)
        self.disable_z, forward_type = checkpostfix("_noz", forward_type)
        self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
        self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False)

        # forward_type debug =======================================
        FORWARD_TYPES = dict(
            m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state),
        )
        self.forward_core = FORWARD_TYPES.get(forward_type, None)
        k_group = 4

        # in proj =======================================
        d_proj = d_inner if self.disable_z else (d_inner * 2)
        self.in_proj = Linear(d_model, d_proj, bias=bias)
        self.act: nn.Module = act_layer()

        # conv =======================================
        if self.with_dconv:
            self.conv2d = nn.Sequential(
                Permute(0, 3, 1, 2),
                nn.Conv2d(
                    in_channels=d_inner,
                    out_channels=d_inner,
                    groups=d_inner,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                    **factory_kwargs,
                ),
                Permute(0, 2, 3, 1),
            )

        # x proj ============================
        # Build k_group Linear layers only to reuse their default init, then keep the stacked weights.
        self.x_proj = [
            nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
            for _ in range(k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj

        # out proj =======================================
        self.out_act = nn.GELU() if self.oact else nn.Identity()
        self.out_proj = Linear(d_inner, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        if initialize in ["v1"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
            self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank)))  # 0.1 is added in 0430
        elif initialize in ["v2"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
            self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank)))

        # init state ============================
        self.initial_state = None
        if with_initial_state:
            self.initial_state = nn.Parameter(
                torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), requires_grad=False
            )

    def forward_corem0(
        self,
        x: torch.Tensor = None,
        # ==============================
        force_fp32=False,  # True: input fp32
        chunk_size=64,
        dstate=64,
        # ==============================
        selective_scan_backend=None,
        scan_mode="cross2d",
        scan_force_torch=False,
        # ==============================
        **kwargs,
    ):
        """Chunked multi-direction selective scan on a channel-last map.

        ``x`` is unpacked as ``(B, H, W, R * D)`` where ``(K, R, D) = self.Ds.shape``.
        Returns ``out_norm(y)`` of shape ``(B, H, W, -1)`` cast to ``x.dtype``. When
        ``self.initial_state`` is set it seeds the scan and is replaced after every call
        by the final state summed over dim 0.
        """
        assert scan_mode in ["unidi", "bidi", "cross2d"]
        assert selective_scan_backend in [None, "triton", "torch"]
        x_proj_bias = getattr(self, "x_proj_bias", None)
        to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

        N = dstate
        B, H, W, RD = x.shape
        K, R = self.A_logs.shape
        K, R, D = self.Ds.shape
        assert RD == R * D
        L = H * W
        KR = K * R
        _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode]

        initial_state = None
        if self.initial_state is not None:
            assert self.initial_state.shape[-1] == dstate
            initial_state = self.initial_state.detach().repeat(B, 1, 1, 1)

        # consumed below by the einsum as (B, L, K, D)
        xs = cross_scan_fn(x.view(B, H, W, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)
        x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight)
        if x_proj_bias is not None:
            # x_dbl is (b, l, k, c): broadcast a per-(k, c) bias over batch and length.
            x_dbl = x_dbl + x_proj_bias.view(1, 1, K, -1)
        dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3)
        xs = xs.contiguous().view(B, L, KR, D)
        dts = dts.contiguous().view(B, L, KR)
        Bs = Bs.contiguous().view(B, L, K, N)
        Cs = Cs.contiguous().view(B, L, K, N)
        if force_fp32:
            xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)

        As = -self.A_logs.to(torch.float).exp().view(KR)
        Ds = self.Ds.to(torch.float).view(KR, D)
        dt_bias = self.dt_projs_bias.view(KR)

        ys, final_state = selective_scan_chunk_fn(
            xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias,
            initial_states=initial_state, dt_softplus=True, return_final_states=True,
            backend=selective_scan_backend,
        )
        y: torch.Tensor = cross_merge_fn(ys.view(B, H, W, K, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)

        if getattr(self, "__DEBUG__", False):
            setattr(self, "__data__", dict(
                A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds,
                us=xs, dts=dts, delta_bias=self.dt_projs_bias,
                initial_state=self.initial_state, final_satte=final_state,
                ys=ys, y=y, H=H, W=W,
            ))
        if self.initial_state is not None:
            self.initial_state = nn.Parameter(final_state.detach().sum(0, keepdim=True), requires_grad=False)

        y = self.out_norm(y.view(B, H, W, -1))
        return y.to(x.dtype)

    def forwardm0(self, x: torch.Tensor, **kwargs):
        """Gated m0 forward on a channel-last ``x``: in_proj -> gate split -> dwconv -> scan -> out_proj."""
        x = self.in_proj(x)
        if not self.disable_z:
            # m0 is channel-last only, so the (x, z) halves live on the last dim.
            x, z = x.chunk(2, dim=-1)  # (b, h, w, d)
            if not self.disable_z_act:
                z = self.act(z)
        if self.with_dconv:
            x = self.conv2d(x)  # permutes internally; stays (b, h, w, d)
        x = self.act(x)
        y = self.forward_core(x)
        y = self.out_act(y)
        if not self.disable_z:
            y = y * z
        out = self.dropout(self.out_proj(y))
        return out
class SS2D(nn.Module, SS2Dv0, SS2Dv2, SS2Dv3, SS2Dm0):
    """2D selective-scan module whose implementation is picked by ``forward_type``.

    Dispatch: "v0" / "v0seq" -> ``__initv0__`` (with ``seq`` set for "v0seq");
    any "xv*" -> ``__initxv__``; any other "m*" -> ``__initm0__``; everything else ->
    ``__initv2__``. The chosen initializer receives all constructor arguments plus any
    extra keyword arguments.
    """

    def __init__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.SiLU,
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v0",
        # ======================
        forward_type="v2",
        channel_first=False,
        # ======================
        **kwargs,
    ):
        nn.Module.__init__(self)
        options = {
            **kwargs,
            "d_model": d_model, "d_state": d_state, "ssm_ratio": ssm_ratio, "dt_rank": dt_rank,
            "act_layer": act_layer, "d_conv": d_conv, "conv_bias": conv_bias,
            "dropout": dropout, "bias": bias,
            "dt_min": dt_min, "dt_max": dt_max, "dt_init": dt_init,
            "dt_scale": dt_scale, "dt_init_floor": dt_init_floor,
            "initialize": initialize, "forward_type": forward_type, "channel_first": channel_first,
        }
        if forward_type in ("v0", "v0seq"):
            self.__initv0__(seq=("seq" in forward_type), **options)
            return
        if forward_type.startswith("xv"):
            initializer = self.__initxv__
        elif forward_type.startswith("m"):
            initializer = self.__initm0__
        else:
            initializer = self.__initv2__
        initializer(**options)
- # =====================================================
class VSSBlock(nn.Module):
    """Residual VMamba block: an SS2D token mixer and/or an MLP, each wrapped in a residual.

    A branch is built only when its ratio is positive (``ssm_ratio > 0`` /
    ``mlp_ratio > 0``). With ``post_norm`` the norm is applied after the branch,
    otherwise before it. ``use_checkpoint`` wraps the forward in
    ``torch.utils.checkpoint``.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
        channel_first=False,
        # =============================
        ssm_d_state: int = 16,
        ssm_ratio=2.0,
        ssm_dt_rank: Any = "auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv: int = 3,
        ssm_conv_bias=True,
        ssm_drop_rate: float = 0,
        ssm_init="v0",
        forward_type="v2",
        # =============================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate: float = 0.0,
        gmlp=False,
        # =============================
        use_checkpoint: bool = False,
        post_norm: bool = False,
        # =============================
        _SS2D: type = SS2D,
        **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            # bias and the dt_* initialization options keep the SS2D defaults.
            ssm_options = dict(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                dropout=ssm_drop_rate,
                initialize=ssm_init,
                forward_type=forward_type,
                channel_first=channel_first,
            )
            self.op = _SS2D(**ssm_options)

        self.drop_path = DropPath(drop_path)

        if self.mlp_branch:
            mlp_cls = gMlp if gmlp else Mlp
            self.norm2 = norm_layer(hidden_dim)
            self.mlp = mlp_cls(
                in_features=hidden_dim,
                hidden_features=int(hidden_dim * mlp_ratio),
                act_layer=mlp_act_layer,
                drop=mlp_drop_rate,
                channels_first=channel_first,
            )

    def _forward(self, input: torch.Tensor):
        """Apply the enabled residual branches (SS2D first, then MLP)."""
        out = input
        if self.ssm_branch:
            mixed = self.norm(self.op(out)) if self.post_norm else self.op(self.norm(out))
            out = out + self.drop_path(mixed)
        if self.mlp_branch:
            ffn = self.norm2(self.mlp(out)) if self.post_norm else self.mlp(self.norm2(out))
            out = out + self.drop_path(ffn)  # FFN
        return out

    def forward(self, input: torch.Tensor):
        """Run ``_forward``, under activation checkpointing when ``use_checkpoint`` is set."""
        if not self.use_checkpoint:
            return self._forward(input)
        return checkpoint.checkpoint(self._forward, input)
class VSSM(nn.Module):
    """VMamba image classifier: patch embedding -> stages of VSSBlocks (+ downsample) -> head.

    ``norm_layer`` ("ln", "ln2d" or "bn", case-insensitive) also selects the feature
    layout: "ln2d"/"bn" keep features channel-first ``(B, C, H, W)``, "ln" keeps them
    channel-last ``(B, H, W, C)`` between the patch embedding and the classifier.
    """

    def __init__(
        self,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        depths=[2, 2, 9, 2],
        dims=[96, 192, 384, 768],
        # =========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer="silu",
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v2",
        # =========================
        mlp_ratio=4.0,
        mlp_act_layer="gelu",
        mlp_drop_rate=0.0,
        gmlp=False,
        # =========================
        drop_path_rate=0.1,
        patch_norm=True,
        norm_layer="LN",  # "BN", "LN2D"
        downsample_version: str = "v2",  # "v1", "v2", "v3", "none"
        patchembed_version: str = "v1",  # "v1", "v2"
        use_checkpoint=False,
        # =========================
        posembed=False,
        imgsize=224,
        _SS2D=SS2D,
        # =========================
        **kwargs,
    ):
        """Build the network.

        Args:
            depths: number of VSSBlocks per stage.
            dims: per-stage widths, or an int ``d`` expanded to ``[d, 2d, 4d, ...]``.
            norm_layer: "ln", "ln2d" or "bn".
            ssm_act_layer, mlp_act_layer: "silu", "gelu", "relu" or "sigmoid".
            downsample_version: "v1" (PatchMerging2D), "v2" (2x2 stride-2 conv),
                "v3" (3x3 stride-2 conv) or "none" (no downsampling between stages).
            patchembed_version: "v1" (one strided conv) or "v2" (two strided convs).
            posembed, imgsize: add a learned positional embedding sized for
                ``imgsize // patch_size`` patches per side.
            drop_path_rate: stochastic-depth rate, increased linearly from 0 over all blocks.
            The remaining ``ssm_*`` / ``mlp_*`` / ``forward_type`` / ``gmlp`` / ``_SS2D``
            options are forwarded to every VSSBlock; extra ``**kwargs`` are ignored.
        """
        super().__init__()
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.num_features = dims[-1]
        self.dims = dims
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        _ACTLAYERS = dict(
            silu=nn.SiLU,
            gelu=nn.GELU,
            relu=nn.ReLU,
            sigmoid=nn.Sigmoid,
        )
        norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)
        ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
        mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)

        self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None

        _make_patch_embed = dict(
            v1=self._make_patch_embed,
            v2=self._make_patch_embed_v2,
        ).get(patchembed_version, None)
        self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first)

        _make_downsample = dict(
            v1=PatchMerging2D,
            v2=self._make_downsample,
            v3=self._make_downsample_v3,
            # A stage must still hold a callable module: a None entry would be invoked by
            # nn.Sequential.forward (and by Backbone_VSSM's layer_forward) and crash.
            none=(lambda *_, **_k: nn.Identity()),
        ).get(downsample_version, None)

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            downsample = _make_downsample(
                self.dims[i_layer],
                self.dims[i_layer + 1],
                norm_layer=norm_layer,
                channel_first=self.channel_first,
            ) if (i_layer < self.num_layers - 1) else nn.Identity()

            self.layers.append(self._make_layer(
                dim=self.dims[i_layer],
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                use_checkpoint=use_checkpoint,
                norm_layer=norm_layer,
                downsample=downsample,
                channel_first=self.channel_first,
                # =================
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                # =================
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                gmlp=gmlp,
                # =================
                _SS2D=_SS2D,
            ))

        self.classifier = nn.Sequential(OrderedDict(
            norm=norm_layer(self.num_features),  # B,H,W,C
            permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

        self.apply(self._init_weights)

    @staticmethod
    def _pos_embed(embed_dims, patch_size, img_size):
        """Create a truncated-normal (std 0.02) ``(1, C, H/p, W/p)`` positional embedding."""
        patch_height, patch_width = (img_size // patch_size, img_size // patch_size)
        pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width))
        trunc_normal_(pos_embed, std=0.02)
        return pos_embed

    def _init_weights(self, m: nn.Module):
        """Truncated-normal Linear weights with zero bias; LayerNorm weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # used in building optimizer
    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed"}

    # used in building optimizer
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {}

    @staticmethod
    def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()),
        )

    @staticmethod
    def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        stride = patch_size // 2
        kernel_size = stride + 1
        padding = 1
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
            (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
            (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()),
        )

    @staticmethod
    def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
            nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            norm_layer(out_dim),
        )

    @staticmethod
    def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
            nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            norm_layer(out_dim),
        )

    @staticmethod
    def _make_layer(
        dim=96,
        drop_path=[0.1, 0.1],
        use_checkpoint=False,
        norm_layer=nn.LayerNorm,
        downsample=nn.Identity(),
        channel_first=False,
        # ===========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v2",
        # ===========================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate=0.0,
        gmlp=False,
        # ===========================
        _SS2D=SS2D,
        **kwargs,
    ):
        """Build one stage: ``len(drop_path)`` VSSBlocks followed by ``downsample``."""
        # if channel first, then Norm and Output are both channel_first
        depth = len(drop_path)
        blocks = []
        for d in range(depth):
            blocks.append(VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[d],
                norm_layer=norm_layer,
                channel_first=channel_first,
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                gmlp=gmlp,
                use_checkpoint=use_checkpoint,
                _SS2D=_SS2D,
            ))

        return nn.Sequential(OrderedDict(
            blocks=nn.Sequential(*blocks,),
            downsample=downsample,
        ))

    def forward(self, x: torch.Tensor):
        """Classify an image batch ``x`` of shape ``(B, in_chans, H, W)``; returns logits."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed
            x = x + pos_embed
        for layer in self.layers:
            x = layer(x)
        x = self.classifier(x)
        return x

    def flops(self, shape=(3, 224, 224), verbose=True):
        """Count FLOPs of one forward pass on a ``(1, *shape)`` random input.

        Runs a deep copy of the model on CUDA in eval mode (a GPU is required). Returns
        the sum of the per-operator counts from ``flop_count`` multiplied by 1e9
        (presumably GFLOPs -> FLOPs; confirm against flop_count's units).
        """
        supported_ops = {
            "aten::silu": None,  # as relu is in _IGNORED_OPS
            "aten::neg": None,  # as relu is in _IGNORED_OPS
            "aten::exp": None,  # as relu is in _IGNORED_OPS
            "aten::flip": None,  # as permute is in _IGNORED_OPS
            "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose),
        }

        model = copy.deepcopy(self)
        model.cuda().eval()

        input = torch.randn((1, *shape), device=next(model.parameters()).device)
        Gflops, _unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)

        del model, input
        return sum(Gflops.values()) * 1e9

    # used to load ckpt from previous training code
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Rename keys from earlier VMamba checkpoints in place, then defer to nn.Module.

        Under ``prefix``: ``patch_embed.proj``/``patch_embed.norm`` -> ``patch_embed.0``/``.2``;
        ``layers.{i}.blocks.{j}.ln_1``/``self_attention`` -> ``.norm``/``.op`` for i, j in
        0..99; ``norm``/``head`` -> ``classifier.norm``/``classifier.head``. A checkpoint
        ``pos_embed`` is resized (bicubic) to this model's grid when the model has one.
        """
        import re

        def check_name(src, state_dict: dict = state_dict, strict=False):
            if strict:
                return prefix + src in state_dict
            key = prefix + src
            return any(k.startswith(key) for k in state_dict)

        def change_name(src, dst, state_dict: dict = state_dict, strict=False):
            if strict:
                if prefix + src in state_dict:
                    state_dict[prefix + dst] = state_dict.pop(prefix + src)
            else:
                key = prefix + src
                for k in list(state_dict.keys()):
                    if k.startswith(key):
                        state_dict[prefix + dst + k[len(key):]] = state_dict.pop(k)

        # A model built with posembed=False has pos_embed None: leave the key alone so it
        # is reported as unexpected instead of crashing on `None.shape`.
        if self.pos_embed is not None and check_name("pos_embed", strict=True):
            srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"]
            state_dict[prefix + "pos_embed"] = F.interpolate(
                srcEmb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic"
            ).to(srcEmb.device)

        change_name("patch_embed.proj", "patch_embed.0")
        change_name("patch_embed.norm", "patch_embed.2")

        # One pass over the keys instead of 100 x 100 prefix scans. Indices are matched
        # exactly as str(i) / str(j) for i, j in range(100) (no leading zeros).
        block_key = re.compile(r"layers\.(0|[1-9]\d?)\.blocks\.(0|[1-9]\d?)\.(ln_1|self_attention)")
        new_leaf = {"ln_1": "norm", "self_attention": "op"}
        for k in list(state_dict.keys()):
            if not k.startswith(prefix):
                continue
            m = block_key.match(k, len(prefix))
            if m is None:
                continue
            i, j, leaf = m.groups()
            state_dict[f"{prefix}layers.{i}.blocks.{j}.{new_leaf[leaf]}{k[m.end():]}"] = state_dict.pop(k)

        change_name("norm", "classifier.norm")
        change_name("head", "classifier.head")

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
- # compatible with openmmlab
class Backbone_VSSM(VSSM):
    """VSSM feature extractor that returns normalized per-stage feature maps.

    The classifier head is removed. ``forward`` returns one channel-first
    ``(B, C, H, W)`` map per index in ``out_indices``, each passed through its own
    ``outnorm{i}`` layer. Note: unlike ``VSSM.forward``, ``pos_embed`` is not added here.
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)

        self.out_indices = out_indices
        for i in out_indices:
            layer = norm_layer(self.dims[i])
            layer_name = f'outnorm{i}'
            self.add_module(layer_name, layer)

        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Best-effort, non-strict load of ``checkpoint[key]`` from the file path ``ckpt``.

        Failures are printed rather than raised, so a missing or malformed checkpoint
        leaves the randomly initialized weights in place.
        """
        if ckpt is None:
            return

        try:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            with open(ckpt, "rb") as f:  # close the handle even if loading fails
                _ckpt = torch.load(f, map_location=torch.device("cpu"))
            print(f"Successfully load ckpt {ckpt}")
            incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatibleKeys)
        except Exception as e:
            print(f"Failed loading checkpoint from {ckpt}: {e}")

    def forward(self, x):
        """Return the list of normalized stage outputs (or the last map if out_indices is empty)."""
        def layer_forward(l, x):
            # Pre-downsample features are reported; the downsampled map feeds the next stage.
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)  # (B, H, W, C)
            if i in self.out_indices:
                norm_layer = getattr(self, f'outnorm{i}')
                out = norm_layer(o)
                if not self.channel_first:
                    out = out.permute(0, 3, 1, 2)
                outs.append(out.contiguous())

        if len(self.out_indices) == 0:
            return x

        return outs
- # =====================================================
def vanilla_vmamba_tiny():
    """Original VMamba-T classifier: v0 SS2D, no MLP, depths 2-2-9-2, base width 96."""
    config = dict(
        # architecture
        depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000,
        # SS2D
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v0",
        # MLP (disabled)
        mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # stem / downsampling / misc
        patch_norm=True, norm_layer="ln",
        downsample_version="v1", patchembed_version="v1",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vanilla_vmamba_small():
    """Original VMamba-S classifier: v0 SS2D, no MLP, depths 2-2-27-2, base width 96."""
    config = dict(
        # architecture
        depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000,
        # SS2D
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v0",
        # MLP (disabled)
        mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # stem / downsampling / misc
        patch_norm=True, norm_layer="ln",
        downsample_version="v1", patchembed_version="v1",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vanilla_vmamba_base():
    """Original VMamba-B classifier: v0 SS2D, no MLP, depths 2-2-27-2, base width 128."""
    config = dict(
        # architecture
        depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6,
        patch_size=4, in_chans=3, num_classes=1000,
        # SS2D
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v0",
        # MLP (disabled)
        mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # stem / downsampling / misc
        patch_norm=True, norm_layer="ln",
        downsample_version="v1", patchembed_version="v1",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
- # =====================================================
def vmamba_tiny_s2l5(channel_first=True):
    """VMamba-T (ssm_ratio 2, depths 2-2-5-2, width 96) with v05_noz SS2D and MLP.

    ``channel_first`` selects "ln2d" (channel-first) instead of "ln" norms.
    """
    norm = "ln2d" if channel_first else "ln"
    config = dict(
        depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer=norm,
        downsample_version="v3", patchembed_version="v2",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vmamba_small_s2l15(channel_first=True):
    """VMamba-S (ssm_ratio 2, depths 2-2-15-2, width 96) with v05_noz SS2D and MLP.

    ``channel_first`` selects "ln2d" (channel-first) instead of "ln" norms.
    """
    norm = "ln2d" if channel_first else "ln"
    config = dict(
        depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer=norm,
        downsample_version="v3", patchembed_version="v2",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vmamba_base_s2l15(channel_first=True):
    """Return a ``VSSM`` configured as the VMamba-Base ``s2l15`` variant.

    The ``s2l15`` suffix matches ``ssm_ratio=2.0`` and a third-stage depth
    of 15 (``depths=[2, 2, 15, 2]``). Uses ``dims=128``,
    ``drop_path_rate=0.6``, ``ssm_d_state=1`` and the ``"v05_noz"`` forward
    type.

    Args:
        channel_first: if truthy, use the ``"ln2d"`` norm layer, otherwise
            ``"ln"``.
    """
    if channel_first:
        norm = "ln2d"
    else:
        norm = "ln"
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer=norm,
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
# =====================================================
def vmamba_tiny_s1l8(channel_first=True):
    """Return a ``VSSM`` configured as the VMamba-Tiny ``s1l8`` variant.

    The ``s1l8`` suffix matches ``ssm_ratio=1.0`` and a third-stage depth of
    8 (``depths=[2, 2, 8, 2]``). Uses ``dims=96``, ``drop_path_rate=0.2``,
    ``ssm_d_state=1`` and the ``"v05_noz"`` forward type.

    Args:
        channel_first: if truthy, use the ``"ln2d"`` norm layer, otherwise
            ``"ln"``.
    """
    if channel_first:
        norm = "ln2d"
    else:
        norm = "ln"
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer=norm,
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
def vmamba_small_s1l20(channel_first=True):
    """Return a ``VSSM`` configured as the VMamba-Small ``s1l20`` variant.

    The ``s1l20`` suffix matches ``ssm_ratio=1.0`` and a third-stage depth
    of 20 (``depths=[2, 2, 20, 2]``). Uses ``dims=96``,
    ``drop_path_rate=0.3``, ``ssm_d_state=1`` and the ``"v05_noz"`` forward
    type.

    Args:
        channel_first: if truthy, use the ``"ln2d"`` norm layer, otherwise
            ``"ln"``.
    """
    if channel_first:
        norm = "ln2d"
    else:
        norm = "ln"
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer=norm,
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
def vmamba_base_s1l20(channel_first=True):
    """Return a ``VSSM`` configured as the VMamba-Base ``s1l20`` variant.

    The ``s1l20`` suffix matches ``ssm_ratio=1.0`` and a third-stage depth
    of 20 (``depths=[2, 2, 20, 2]``). Uses ``dims=128``,
    ``drop_path_rate=0.5``, ``ssm_d_state=1`` and the ``"v05_noz"`` forward
    type.

    Args:
        channel_first: if truthy, use the ``"ln2d"`` norm layer, otherwise
            ``"ln"``.
    """
    if channel_first:
        norm = "ln2d"
    else:
        norm = "ln"
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer=norm,
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
# mamba2 support =====================================================
# NOTE: FLOPs counting does not work for the mamba2 configurations yet!
def vmamba_tiny_m2():
    """Return a ``VSSM`` configured as the tiny mamba2 (``"m0_noz"``) variant.

    Layout: ``depths=[2, 2, 4, 2]``, ``dims=96``, ``drop_path_rate=0.2``.
    SSM settings: ``ssm_d_state=64``, ``ssm_ratio=1.0``, ``"gelu"``
    activation, ``"v2"`` init. Uses a plain ``"ln"`` norm.
    """
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v2", forward_type="m0_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer="ln",
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
def vmamba_small_m2():
    """Return a ``VSSM`` configured as the small mamba2 (``"m0_noz"``) variant.

    Layout: ``depths=[2, 2, 12, 2]``, ``dims=96``, ``drop_path_rate=0.3``.
    SSM settings: ``ssm_d_state=64``, ``ssm_ratio=1.0``, ``"gelu"``
    activation, ``"v2"`` init. Uses a plain ``"ln"`` norm.

    NOTE: ``flops()`` is reported not to be correct for mamba2 models.
    """
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 12, 2], dims=96, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v2", forward_type="m0_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer="ln",
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
def vmamba_base_m2():
    """Return a ``VSSM`` configured as the base mamba2 (``"m0_noz"``) variant.

    Layout: ``depths=[2, 2, 12, 2]``, ``dims=128``, ``drop_path_rate=0.3``.
    SSM settings: ``ssm_d_state=64``, ``ssm_ratio=1.0``, ``"gelu"``
    activation, ``"v2"`` init. Uses a plain ``"ln"`` norm.

    NOTE: ``flops()`` is reported not to be correct for mamba2 models.
    """
    config = dict(
        # network layout and input/output sizes
        depths=[2, 2, 12, 2], dims=128, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
        # SSM settings
        ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v2", forward_type="m0_noz",
        # MLP settings
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        # normalisation and stem / downsampling layer versions
        patch_norm=True, norm_layer="ln",
        downsample_version="v3", patchembed_version="v2",
        # miscellaneous
        use_checkpoint=False, posembed=False,
    )
    return VSSM(**config)
if __name__ == "__main__":
    # Ad-hoc smoke test / micro-benchmark. It builds the s1l8 tiny model as a
    # reference and a tiny mamba2 ("m0_noz") model, then reports the mamba2
    # model's parameter count and FLOPs. Finally it times both models in train
    # mode on the GPU.
    import time

    model_ref = vmamba_tiny_s1l8()
    model = VSSM(
        depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v2", forward_type="m0_noz",
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer="ln",
        downsample_version="v3", patchembed_version="v2",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    print(parameter_count(model)[""])
    # Known to be wrong: FLOPs counting does not support the mamba2 forward yet.
    print(model.flops())
    model.cuda().train()
    model_ref.cuda().train()

    def _mean_seconds(step, iters):
        """Return the mean elapsed seconds per call of ``step()``.

        ``step`` is first called ``iters`` times as a warm-up. It is then
        timed over another ``iters`` calls. ``torch.cuda.synchronize()``
        brackets the timed loop so that asynchronously queued CUDA work is
        included in the measurement. ``time.perf_counter`` is used because
        it is monotonic and high-resolution, unlike ``time.time``.
        """
        for _ in range(iters):
            step()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            step()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters

    def bench(model, iters=30, batch_size=128, img_size=224):
        """Time ``model`` on a random ``(batch_size, 3, img_size, img_size)`` batch.

        Returns a tuple ``(forward_s, forward_backward_s)`` of mean seconds per
        iteration. Gradients are never zeroed between backward passes, so they
        accumulate in the parameters' ``.grad``. That is acceptable for timing.
        """
        inp = torch.randn((batch_size, 3, img_size, img_size)).cuda()
        forward_s = _mean_seconds(lambda: model(inp), iters)
        forward_backward_s = _mean_seconds(lambda: model(inp).sum().backward(), iters)
        return forward_s, forward_backward_s

    print(bench(model_ref))
    print(bench(model))
    # Interactive hook for inspecting the models after benchmarking; set
    # PYTHONBREAKPOINT=0 in the environment to turn it into a no-op.
    breakpoint()
|