| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849 |
- import os
- import time
- import math
- import copy
- from functools import partial
- from typing import Optional, Callable, Any
- from collections import OrderedDict
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- from timm.layers import DropPath, trunc_normal_
- try:
- from .csm_triton import cross_scan_fn, cross_merge_fn
- except:
- from csm_triton import cross_scan_fn, cross_merge_fn
- try:
- from .csms6s import selective_scan_fn, selective_scan_flop_jit
- except:
- from csms6s import selective_scan_fn, selective_scan_flop_jit
- # FLOPs counter not prepared for mamba2.
- # Keep this dependency optional because SS2D(v2/v3) does not require it.
- try:
- from .mamba2.ssd_minimal import selective_scan_chunk_fn
- except Exception:
- try:
- from mamba2.ssd_minimal import selective_scan_chunk_fn
- except Exception:
- selective_scan_chunk_fn = None
- # =====================================================
# nn.Linear and nn.Conv2d initialise their weights differently; this class keeps
# nn.Linear's parameters and initialisation while running as a 1x1 convolution,
# and it can load checkpoints saved from either a Linear or a 1x1 Conv2d layer.
class Linear2d(nn.Linear):
    """``nn.Linear`` applied over the channel axis of a channel-first tensor.

    The weight is stored as (out_features, in_features), exactly like
    ``nn.Linear``, and is applied through ``F.conv2d`` as a 1x1 kernel, so an
    input of shape (B, in_features, H, W) produces (B, out_features, H, W).
    """

    def forward(self, x: torch.Tensor):
        # (out, in) -> (out, in, 1, 1): a 1x1 convolution is a per-pixel linear map.
        return F.conv2d(x, self.weight[:, :, None, None], self.bias)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Accept both Linear-shaped (out, in) and Conv2d-shaped (out, in, 1, 1) weights.

        The stored weight is reshaped to ``self.weight.shape`` when it holds the
        same number of elements. A missing weight key is left for the parent
        implementation to record in ``missing_keys`` (instead of raising
        ``KeyError`` here). A weight with a different element count is also left
        to the parent, which reports it as a size mismatch.
        """
        key = prefix + "weight"
        weight = state_dict.get(key)
        if weight is not None and weight.numel() == self.weight.numel():
            state_dict[key] = weight.reshape(self.weight.shape)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel axis of a channel-first (B, C, H, W) tensor.

    The input is moved to channel-last, normalised over ``normalized_shape``
    with this module's affine parameters, and moved back to channel-first.
    """

    def forward(self, x: torch.Tensor):
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 3, 1, 2)
class PatchMerging2D(nn.Module):
    """2x spatial downsampling: gather each 2x2 patch into channels, normalise, project.

    Args:
        dim: input channel count C.
        out_dim: output channel count; a negative value means ``2 * dim``.
        norm_layer: normalisation built as ``norm_layer(4 * dim)`` and applied to
            the merged tensor in the same layout as the input. With
            ``channel_first=True`` it must normalise the channel axis of a
            (B, 4C, H/2, W/2) tensor (the default ``nn.LayerNorm`` acts on the
            last axis).
        channel_first: True for (B, C, H, W) input, False for (B, H, W, C).

    Odd H or W is zero-padded on the bottom/right before merging.
    """

    def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False):
        super().__init__()
        self.dim = dim
        Linear = Linear2d if channel_first else nn.Linear
        self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last
        self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
        self.norm = norm_layer(4 * dim)

    @staticmethod
    def _patch_merging_pad_channel_last(x: torch.Tensor):
        """(..., H, W, C) -> (..., ceil(H/2), ceil(W/2), 4C)."""
        H, W, _ = x.shape[-3:]
        if (W % 2 != 0) or (H % 2 != 0):
            # F.pad pads from the last dim backwards: (C_lo, C_hi, W_lo, W_hi, H_lo, H_hi).
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
        x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
        x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
        x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
        return torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C

    @staticmethod
    def _patch_merging_pad_channel_first(x: torch.Tensor):
        """(B, C, H, W) -> (B, 4C, ceil(H/2), ceil(W/2)), same channel order as channel-last."""
        H, W = x.shape[-2:]
        if (W % 2 != 0) or (H % 2 != 0):
            # Here the last two dims are H, W: (W_lo, W_hi, H_lo, H_hi). Reusing the
            # channel-last 6-tuple would pad H and C instead of W and H.
            x = F.pad(x, (0, W % 2, 0, H % 2))
        x0 = x[..., 0::2, 0::2]  # B C H/2 W/2
        x1 = x[..., 1::2, 0::2]  # B C H/2 W/2
        x2 = x[..., 0::2, 1::2]  # B C H/2 W/2
        x3 = x[..., 1::2, 1::2]  # B C H/2 W/2
        return torch.cat([x0, x1, x2, x3], 1)  # B 4*C H/2 W/2

    def forward(self, x):
        """Merge 2x2 patches, normalise, and project to the output width."""
        x = self._patch_merging_pad(x)
        x = self.norm(x)
        x = self.reduction(x)
        return x
class Permute(nn.Module):
    """Module wrapper that permutes its input into a fixed dimension order."""

    def __init__(self, *args):
        super().__init__()
        # Dimension order, forwarded verbatim to Tensor.permute.
        self.args = args

    def forward(self, x: torch.Tensor):
        order = self.args
        return torch.permute(x, order)
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> act -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    falsy. With ``channels_first`` the projections are ``Linear2d`` (for
    channel-first tensors); otherwise they are ``nn.Linear``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., channels_first=False):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        width_out = out_features if out_features else in_features
        proj_cls = Linear2d if channels_first else nn.Linear
        self.fc1 = proj_cls(in_features, hidden)
        self.act = act_layer()
        self.fc2 = proj_cls(hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # The same dropout module is applied after the activation and after fc2.
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
class gMlp(nn.Module):
    """Gated MLP: fc1 to 2*hidden, split into (value, gate), fc2(value * act(gate)), dropout.

    The split happens on dim 1 when ``channels_first`` (projections are
    ``Linear2d``) and on the last dim otherwise (projections are ``nn.Linear``).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., channels_first=False):
        super().__init__()
        self.channel_first = channels_first
        width_out = out_features if out_features else in_features
        hidden = hidden_features if hidden_features else in_features
        proj_cls = Linear2d if channels_first else nn.Linear
        self.fc1 = proj_cls(in_features, 2 * hidden)
        self.act = act_layer()
        self.fc2 = proj_cls(hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor):
        split_dim = 1 if self.channel_first else -1
        value, gate = self.fc1(x).chunk(2, dim=split_dim)
        return self.drop(self.fc2(value * self.act(gate)))
class SoftmaxSpatial(nn.Softmax):
    """Softmax over all H * W spatial positions of a 4-D feature map.

    ``dim=-1`` expects channel-first input (B, C, H, W); ``dim=1`` expects
    channel-last input (B, H, W, C). In both cases every (batch, channel) map
    sums to 1. Other ``dim`` values raise ``NotImplementedError``.
    """

    def forward(self, x: torch.Tensor):
        if self.dim == -1:
            B, C, H, W = x.shape
            # reshape (not view): permuted / non-contiguous maps cannot be viewed
            # as (B, C, H*W) and previously raised a RuntimeError.
            return super().forward(x.reshape(B, C, -1)).view(B, C, H, W)
        if self.dim == 1:
            B, H, W, C = x.shape
            return super().forward(x.reshape(B, -1, C)).view(B, H, W, C)
        raise NotImplementedError(
            f"SoftmaxSpatial supports dim=-1 (channel-first) or dim=1 (channel-last), got dim={self.dim}"
        )
- # =====================================================
class mamba_init:
    """Initialisers for the selective-scan parameters (dt projection, A_log, D)."""

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
        """Create the dt projection ``nn.Linear(dt_rank, d_inner, bias=True)``.

        The weight is filled with (``dt_init="constant"``) or drawn uniformly
        within +/- (``dt_init="random"``) the value ``dt_scale / sqrt(dt_rank)``.
        Any other ``dt_init`` raises ``NotImplementedError``. The bias is the
        inverse softplus of dt values drawn log-uniformly in [dt_min, dt_max]
        and clamped below at ``dt_init_floor``, so ``softplus(bias)`` recovers
        those dt values.
        """
        proj = nn.Linear(dt_rank, d_inner, bias=True)

        # Scale chosen to preserve variance at initialisation.
        bound = dt_rank**-0.5 * dt_scale
        weight_inits = {
            "constant": lambda w: nn.init.constant_(w, bound),
            "random": lambda w: nn.init.uniform_(w, -bound, bound),
        }
        if dt_init not in weight_inits:
            raise NotImplementedError
        weight_inits[dt_init](proj.weight)

        log_lo = math.log(dt_min)
        log_hi = math.log(dt_max)
        dt = torch.exp(torch.rand(d_inner) * (log_hi - log_lo) + log_lo).clamp(min=dt_init_floor)
        # softplus^-1(dt) = dt + log(1 - exp(-dt)); expm1 keeps this accurate for small dt.
        # See https://github.com/pytorch/pytorch/issues/72759
        inv_softplus_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            proj.bias.copy_(inv_softplus_dt)
        return proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        """Return log(A) for the S4D-real init, where every row is log([1, ..., d_state]).

        The result has shape (d_inner, d_state). With ``copies > 0`` it becomes
        (copies, d_inner, d_state), flattened to (copies * d_inner, d_state)
        when ``merge`` is true. It is kept in float32 and marked
        ``_no_weight_decay``.
        """
        states = torch.arange(1, d_state + 1, dtype=torch.float32, device=device)
        log_a = torch.log(states.view(1, -1).repeat(d_inner, 1).contiguous())
        if copies > 0:
            log_a = log_a.unsqueeze(0).repeat(copies, 1, 1).contiguous()
            if merge:
                log_a = log_a.flatten(0, 1)
        param = nn.Parameter(log_a)
        param._no_weight_decay = True
        return param

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        """Return the all-ones D skip parameter.

        The result has shape (d_inner,). With ``copies > 0`` it becomes
        (copies, d_inner), flattened to (copies * d_inner,) when ``merge`` is
        true. It is kept in float32 and marked ``_no_weight_decay``.
        """
        skip = torch.ones(d_inner, device=device)
        if copies > 0:
            skip = skip.unsqueeze(0).repeat(copies, 1).contiguous()
            if merge:
                skip = skip.flatten(0, 1)
        param = nn.Parameter(skip)
        param._no_weight_decay = True
        return param

    @classmethod
    def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
        """Build the stacked parameters for ``k_group`` scan directions.

        Returns:
            (A_logs, Ds, dt_projs_weight, dt_projs_bias) with shapes
            (K * d_inner, d_state), (K * d_inner,), (K, d_inner, dt_rank), (K, d_inner).
        """
        weights, biases = [], []
        for _ in range(k_group):
            proj = cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
            weights.append(proj.weight)
            biases.append(proj.bias)
        dt_projs_weight = nn.Parameter(torch.stack(weights, dim=0))  # (K, inner, rank)
        dt_projs_bias = nn.Parameter(torch.stack(biases, dim=0))  # (K, inner)
        del weights, biases

        A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)  # (K * D, N)
        Ds = cls.D_init(d_inner, copies=k_group, merge=True)  # (K * D)
        return A_logs, Ds, dt_projs_weight, dt_projs_bias
- # support: v0, v0seq
class SS2Dv0:
    """Mixin implementing the original ("v0" / "v0seq") SS2D layer.

    The host class is expected to also derive from ``nn.Module`` and to call
    ``__initv0__`` from its constructor (``super().__init__()`` below follows
    the host's MRO). Only channel-last input (B, H, W, d_model) is supported.
    """

    def __initv0__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        # ======================
        dropout=0.0,
        # ======================
        seq=False,
        force_fp32=True,
        **kwargs,
    ):
        """Build the v0 parameters and bind ``self.forward``.

        Args:
            d_model: input/output channel count.
            d_state: SSM state size N.
            ssm_ratio: inner width multiplier; d_inner = int(ssm_ratio * d_model).
            dt_rank: rank of the dt projection, or "auto" for ceil(d_model / 16).
            dropout: dropout after the output projection (identity when 0).
            seq: run the four scan directions one after another instead of batched.
            force_fp32: cast scan inputs to float32 before the selective scan.
            **kwargs: ignored, except ``channel_first``, which must be falsy.
        """
        if "channel_first" in kwargs:
            assert not kwargs["channel_first"]
        act_layer = nn.SiLU
        dt_min = 0.001
        dt_max = 0.1
        dt_init = "random"
        dt_scale = 1.0
        dt_init_floor = 1e-4
        bias = False
        conv_bias = True
        d_conv = 3
        k_group = 4
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

        # Bind both forward options together. Previously a force_fp32=False
        # binding replaced the seq=True binding, silently dropping `seq`.
        forward_kwargs = {}
        if seq:
            forward_kwargs["seq"] = True
        if not force_fp32:
            forward_kwargs["force_fp32"] = False
        self.forward = partial(self.forwardv0, **forward_kwargs) if forward_kwargs else self.forwardv0

        # in proj ============================
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
        self.act: nn.Module = act_layer()
        self.conv2d = nn.Conv2d(
            in_channels=d_inner,
            out_channels=d_inner,
            groups=d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )

        # x proj: one Linear per scan direction, merged into a single (K, N, inner) weight.
        x_proj = [
            nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
            for _ in range(k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_proj], dim=0))  # (K, N, inner)
        del x_proj

        # dt proj, A, D ============================
        self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
            d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=k_group,
        )

        # out proj =======================================
        self.out_norm = nn.LayerNorm(d_inner)
        self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

    def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
        """Apply the v0 SS2D layer to a channel-last (B, H, W, d_model) tensor.

        Args:
            x: channel-last input.
            seq: scan each of the K directions separately instead of in one call.
            force_fp32: cast xs, dts, Bs, Cs to float32 before scanning.

        Returns:
            Tensor of shape (B, H, W, d_model).
        """
        x = self.in_proj(x)
        x, z = x.chunk(2, dim=-1)  # (b, h, w, d)
        z = self.act(z)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.conv2d(x)  # (b, d, h, w)
        x = self.act(x)
        selective_scan = partial(selective_scan_fn, backend="mamba")

        B, _, H, W = x.shape
        N = self.A_logs.shape[1]
        K, D, R = self.dt_projs_weight.shape
        L = H * W

        # Four scan orders: row-major (hw), column-major (wh), and both reversed.
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
        if hasattr(self, "x_proj_bias"):
            x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
        xs = xs.view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.contiguous()  # (b, k, d_state, l)
        Cs = Cs.contiguous()  # (b, k, d_state, l)

        As = -self.A_logs.float().exp()  # (k * d, d_state)
        Ds = self.Ds.float()  # (k * d)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        if force_fp32:
            xs, dts, Bs, Cs = (t.to(torch.float32) for t in (xs, dts, Bs, Cs))

        if seq:
            out_y = []
            for i in range(K):
                yi = selective_scan(
                    xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
                    As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
                    delta_bias=dt_projs_bias.view(K, -1)[i],
                    delta_softplus=True,
                ).view(B, -1, L)
                out_y.append(yi)
            out_y = torch.stack(out_y, dim=1)
        else:
            out_y = selective_scan(
                xs, dts,
                As, Bs, Cs, Ds,
                delta_bias=dt_projs_bias,
                delta_softplus=True,
            ).view(B, K, -1, L)
        # NOTE(review): assumes the "mamba" backend returns float32 here; confirm
        # this also holds when force_fp32=False with half-precision inputs.
        assert out_y.dtype == torch.float

        # Undo the reversals and the column-major ordering, then sum the four directions.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y

        y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
        y = self.out_norm(y).view(B, H, W, -1)
        y = y * z
        out = self.dropout(self.out_proj(y))
        return out
- # support: v01-v05; v051d,v052d,v052dc;
- # postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32;
- # history support: v2,v3;v31d,v32d,v32dc;
class SS2Dv2:
    """Mixin implementing the v2-family SS2D layer.

    Supported forward types are the keys of ``FORWARD_TYPES`` in
    ``__initv2__``, optionally combined with the postfixes parsed there and in
    ``get_outnorm``. The host class is expected to also derive from
    ``nn.Module`` and to call ``__initv2__`` from its constructor
    (``super().__init__()`` below follows the host's MRO).
    """

    def __initv2__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.SiLU,
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v0",
        # ======================
        forward_type="v2",
        channel_first=False,
        # ======================
        **kwargs,
    ):
        """Build the parameters and select ``self.forward_core`` from ``forward_type``.

        Postfixes are stripped from ``forward_type`` in this order:
        ``_no32`` (do not force fp32), ``_oact`` (GELU on the scan output),
        ``_noz`` (no gating branch), ``_nozact`` (gate without activation), and
        then the output-norm postfixes handled by ``get_outnorm``. The remainder
        must be a key of ``FORWARD_TYPES``.

        ``initialize`` selects the dt/A/D init: "v0" uses ``mamba_init``, while
        "v1"/"v2" use simple random inits. Other values leave those parameters
        unset.

        Raises:
            ValueError: if the stripped ``forward_type`` is not supported.
        """
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        self.k_group = 4
        self.d_model = int(d_model)
        self.d_state = int(d_state)
        self.d_inner = int(ssm_ratio * d_model)
        self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
        self.channel_first = channel_first
        self.with_dconv = d_conv > 1
        Linear = Linear2d if channel_first else nn.Linear
        self.forward = self.forwardv2

        # tags for forward_type ==============================
        checkpostfix = self.checkpostfix
        self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
        self.oact, forward_type = checkpostfix("_oact", forward_type)
        self.disable_z, forward_type = checkpostfix("_noz", forward_type)
        self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
        self.out_norm, forward_type = self.get_outnorm(forward_type, self.d_inner, channel_first)

        # forward_type debug =======================================
        force_fp32 = not self.disable_force32
        FORWARD_TYPES = dict(
            v01=partial(self.forward_corev2, force_fp32=force_fp32, selective_scan_backend="mamba", scan_force_torch=True),
            v02=partial(self.forward_corev2, force_fp32=force_fp32, selective_scan_backend="mamba"),
            v03=partial(self.forward_corev2, force_fp32=force_fp32, selective_scan_backend="oflex"),
            v04=partial(self.forward_corev2, force_fp32=False),  # selective_scan_backend="oflex", scan_mode="cross2d"
            v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True),  # selective_scan_backend="oflex", scan_mode="cross2d"
            # ===============================
            v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"),
            v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"),
            v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"),
            v052d3=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode=3),  # debug
            # ===============================
            v2=partial(self.forward_corev2, force_fp32=force_fp32, selective_scan_backend="core"),
            v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"),
        )
        self.forward_core = FORWARD_TYPES.get(forward_type, None)
        if self.forward_core is None:
            # Fail at construction instead of with "'NoneType' object is not callable" in forward.
            raise ValueError(
                f"SS2Dv2: unsupported forward_type {forward_type!r} (after stripping postfixes); "
                f"expected one of {sorted(FORWARD_TYPES)}"
            )

        # in proj =======================================
        d_proj = self.d_inner if self.disable_z else (self.d_inner * 2)
        self.in_proj = Linear(self.d_model, d_proj, bias=bias)
        self.act: nn.Module = act_layer()

        # conv =======================================
        if self.with_dconv:
            self.conv2d = nn.Conv2d(
                in_channels=self.d_inner,
                out_channels=self.d_inner,
                groups=self.d_inner,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )

        # x proj: one Linear per scan direction, merged into a single (K, N, inner) weight.
        x_proj = [
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False)
            for _ in range(self.k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_proj], dim=0))  # (K, N, inner)
        del x_proj

        # out proj =======================================
        self.out_act = nn.GELU() if self.oact else nn.Identity()
        self.out_proj = Linear(self.d_inner, self.d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        if initialize in ["v0"]:
            self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
                self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
            )
        elif initialize in ["v1"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
            self.A_logs = nn.Parameter(torch.randn((self.k_group * self.d_inner, self.d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner, self.dt_rank)))  # 0.1 is added in 0430
            self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner)))  # 0.1 is added in 0430
        elif initialize in ["v2"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
            self.A_logs = nn.Parameter(torch.zeros((self.k_group * self.d_inner, self.d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner, self.dt_rank)))
            self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner)))

    def forward_corev2(
        self,
        x: torch.Tensor = None,
        # ==============================
        force_fp32=False,  # True: input fp32
        # ==============================
        ssoflex=True,  # True: input 16 or 32 output 32 False: output dtype as input
        no_einsum=False,  # replace einsum with linear or conv1d to raise throughput
        # ==============================
        selective_scan_backend=None,
        # ==============================
        scan_mode="cross2d",
        scan_force_torch=False,
        # ==============================
        **kwargs,
    ):
        """Run the 2D selective scan on a channel-first (B, d_inner, H, W) map.

        Args:
            x: channel-first input (as produced by ``forwardv2``).
            force_fp32: cast the scan inputs (xs, dts, Bs, Cs) to float32.
            ssoflex: forwarded to ``selective_scan_fn``.
            no_einsum: use grouped ``F.conv1d`` instead of ``torch.einsum`` for the projections.
            selective_scan_backend: backend name forwarded to ``selective_scan_fn``.
            scan_mode: "cross2d" (0), "unidi" (1), "bidi" (2), "cascade2d" (-1), or a raw int.
            scan_force_torch: forwarded to ``cross_scan_fn``/``cross_merge_fn`` (non-cascade path).

        Returns:
            ``out_norm`` applied to the scan result, laid out as (B, C, H, W) when
            ``channel_first`` and (B, H, W, C) otherwise, cast to ``x.dtype``.
        """
        # "core" is the backend requested by the "v2" forward type; without it
        # forward_type="v2" always failed here. NOTE(review): assumes
        # selective_scan_fn (csms6s) accepts backend="core" — confirm.
        assert selective_scan_backend in [None, "oflex", "mamba", "torch", "core"]
        _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode  # for debug
        assert isinstance(_scan_mode, int)
        delta_softplus = True
        out_norm = self.out_norm
        channel_first = self.channel_first
        to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
        B, _, H, W = x.shape
        N = self.d_state
        K, D, R = self.k_group, self.d_inner, self.dt_rank
        L = H * W
        x_proj_bias = getattr(self, "x_proj_bias", None)

        def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
            return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend)

        if _scan_mode == -1:
            # cascade2d: bidirectional scan along each row, then along each column of the row result.
            def scan_rowcol(
                x: torch.Tensor,
                proj_weight: torch.Tensor,  # (2, R + 2N, D)
                proj_bias: torch.Tensor,  # (2, R + 2N) or None
                dt_weight: torch.Tensor,  # (2, D, R)
                dt_bias: torch.Tensor,  # (2, D)
                _As: torch.Tensor,  # (2 * D, d_state)
                _Ds: torch.Tensor,  # (2 * D)
                width=True,
            ):
                # x: (B, D, H, W); width=True scans along W, otherwise along H.
                XB, XD, XH, XW = x.shape
                if width:
                    _B, _D, _L = XB * XH, XD, XW
                    xs = x.permute(0, 2, 1, 3).contiguous()  # (B, H, D, W)
                else:
                    _B, _D, _L = XB * XW, XD, XH
                    xs = x.permute(0, 3, 1, 2).contiguous()  # (B, W, D, H)
                # Forward + reversed copy of every line, as (_B, 2, D, _L); the
                # einsum path below requires this 4-D layout.
                xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2).view(_B, 2, _D, _L)
                if no_einsum:
                    x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2)
                    dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2)
                    dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2)
                else:
                    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight)
                    if proj_bias is not None:
                        # Use this half's 2-group bias, not the full 4-group x_proj_bias.
                        x_dbl = x_dbl + proj_bias.view(1, 2, -1, 1)
                    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
                    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_weight)
                xs = xs.view(_B, -1, _L)
                dts = dts.contiguous().view(_B, -1, _L)
                As = _As.view(-1, N).to(torch.float)
                Bs = Bs.contiguous().view(_B, 2, N, _L)
                Cs = Cs.contiguous().view(_B, 2, N, _L)
                Ds = _Ds.view(-1)
                delta_bias = dt_bias.view(-1).to(torch.float)
                if force_fp32:
                    xs = xs.to(torch.float)
                dts = dts.to(xs.dtype)
                Bs = Bs.to(xs.dtype)
                Cs = Cs.to(xs.dtype)
                ys: torch.Tensor = selective_scan(
                    xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
                ).view(_B, 2, -1, _L)
                return ys

            As = -self.A_logs.to(torch.float).exp().view(4, -1, N)
            x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, 2).contiguous()  # added0510 to avoid nan
            y_row = scan_rowcol(
                x,
                proj_weight=self.x_proj_weight.view(4, -1, D)[:2].contiguous(),
                proj_bias=(x_proj_bias.view(4, -1)[:2].contiguous() if x_proj_bias is not None else None),
                dt_weight=self.dt_projs_weight.view(4, D, -1)[:2].contiguous(),
                dt_bias=(self.dt_projs_bias.view(4, -1)[:2].contiguous() if self.dt_projs_bias is not None else None),
                _As=As[:2].contiguous().view(-1, N),
                _Ds=self.Ds.view(4, -1)[:2].contiguous().view(-1),
                width=True,
            ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3)  # (B,C,H,W)
            y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, 2).contiguous()  # added0510 to avoid nan
            y_col = scan_rowcol(
                y_row,
                proj_weight=self.x_proj_weight.view(4, -1, D)[2:].contiguous().to(y_row.dtype),
                proj_bias=(x_proj_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if x_proj_bias is not None else None),
                dt_weight=self.dt_projs_weight.view(4, D, -1)[2:].contiguous().to(y_row.dtype),
                dt_bias=(self.dt_projs_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if self.dt_projs_bias is not None else None),
                _As=As[2:].contiguous().view(-1, N),
                _Ds=self.Ds.view(4, -1)[2:].contiguous().view(-1),
                width=False,
            ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1)  # (B,C,H,W), non-contiguous
            y = y_col
        else:
            xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
            if no_einsum:
                x_dbl = F.conv1d(xs.view(B, -1, L), self.x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
                dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
                if hasattr(self, "dt_projs_weight"):
                    dts = F.conv1d(dts.contiguous().view(B, -1, L), self.dt_projs_weight.view(K * D, -1, 1), groups=K)
            else:
                x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
                if x_proj_bias is not None:
                    x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
                dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
                if hasattr(self, "dt_projs_weight"):
                    dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
            xs = xs.view(B, -1, L)
            dts = dts.contiguous().view(B, -1, L)
            As = -self.A_logs.to(torch.float).exp()  # (k * c, d_state)
            Ds = self.Ds.to(torch.float)  # (K * c)
            Bs = Bs.contiguous().view(B, K, N, L)
            Cs = Cs.contiguous().view(B, K, N, L)
            delta_bias = self.dt_projs_bias.view(-1).to(torch.float)
            if force_fp32:
                xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
            ys: torch.Tensor = selective_scan(
                xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
            ).view(B, K, -1, H, W)

            y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
            if getattr(self, "__DEBUG__", False):
                setattr(self, "__data__", dict(
                    A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
                    us=xs, dts=dts, delta_bias=delta_bias,
                    ys=ys, y=y, H=H, W=W,
                ))

        # reshape/permute instead of view: y_col from cascade2d is non-contiguous,
        # and the former y.view(B, -1, H * W) raised for channel-last models.
        y = y.reshape(B, -1, H, W)
        if not channel_first:
            y = y.permute(0, 2, 3, 1).contiguous()  # (B, H, W, C)
        y = out_norm(y)
        return y.to(x.dtype)

    def forwardv2(self, x: torch.Tensor, **kwargs):
        """Apply the layer to (B, H, W, d_model), or to (B, d_model, H, W) when ``channel_first``.

        Steps: in_proj, optional split into (x, z) with z optionally activated,
        optional depthwise conv, activation, ``forward_core``, ``out_act``,
        optional gating by z, out_proj, dropout.
        """
        x = self.in_proj(x)
        if not self.disable_z:
            x, z = x.chunk(2, dim=(1 if self.channel_first else -1))  # (b, h, w, d)
            if not self.disable_z_act:
                z = self.act(z)
        if not self.channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        if self.with_dconv:
            x = self.conv2d(x)  # (b, d, h, w)
        x = self.act(x)
        y = self.forward_core(x)
        y = self.out_act(y)
        if not self.disable_z:
            y = y * z
        out = self.dropout(self.out_proj(y))
        return out

    @staticmethod
    def get_outnorm(forward_type="", d_inner=192, channel_first=True):
        """Build the output norm selected by a postfix of ``forward_type``.

        Recognised postfixes: ``_onnone`` (identity), ``_oncnorm`` (LayerNorm +
        depthwise 3x3 conv), ``_ondwconv3`` (depthwise 3x3 conv),
        ``_onsoftmax`` (spatial softmax), and ``_onsigmoid``. Without any of
        them, a LayerNorm over ``d_inner`` is used.

        Returns:
            (out_norm module, forward_type with the recognised postfixes stripped)
        """
        checkpostfix = SS2Dv2.checkpostfix
        LayerNorm = LayerNorm2d if channel_first else nn.LayerNorm
        out_norm_none, forward_type = checkpostfix("_onnone", forward_type)
        out_norm_dwconv3, forward_type = checkpostfix("_ondwconv3", forward_type)
        out_norm_cnorm, forward_type = checkpostfix("_oncnorm", forward_type)
        out_norm_softmax, forward_type = checkpostfix("_onsoftmax", forward_type)
        out_norm_sigmoid, forward_type = checkpostfix("_onsigmoid", forward_type)
        if out_norm_none:
            out_norm = nn.Identity()
        elif out_norm_cnorm:
            out_norm = nn.Sequential(
                LayerNorm(d_inner),
                (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
                nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
                (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            )
        elif out_norm_dwconv3:
            out_norm = nn.Sequential(
                (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
                nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
                (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            )
        elif out_norm_softmax:
            out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1))
        elif out_norm_sigmoid:
            out_norm = nn.Sigmoid()
        else:
            out_norm = LayerNorm(d_inner)
        return out_norm, forward_type

    @staticmethod
    def checkpostfix(tag, value):
        """Return ``(True, value without tag)`` if ``value`` ends with ``tag``, else ``(False, value)``."""
        if tag and value.endswith(tag):
            return True, value[:-len(tag)]
        return False, value
- # support: xv1a,xv2a,xv3a;
- # postfix: _cpos;_ocov;_ocov2;_ca,_ca1;_act;_mul;_onsigmoid,_onsoftmax,_ondwconv3,_onnone;
class SS2Dv3:
    """SS2D variant ("xv" forward types) whose input projection also emits dt, B and C.

    Used as a mixin of ``SS2D``: ``SS2D.__init__`` calls ``__initxv__`` when
    ``forward_type`` starts with ``"xv"``.

    Modes (first four characters of ``forward_type``): ``xv1a``, ``xv2a``, ``xv3a``.
    Postfixes: ``_cpos``; ``_ocov``; ``_ocov2``; ``_ca``, ``_ca1``; ``_act``; ``_mul``;
    ``_onsigmoid``, ``_onsoftmax``, ``_ondwconv3``, ``_onnone``.
    """

    def __initxv__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v0",
        # ======================
        forward_type="v2",
        channel_first=False,
        # ======================
        **kwargs,
    ):
        """Build the layers for an ``xv*`` forward type.

        Postfixes are consumed from the end of ``forward_type`` in this order:
        out-norm (``_on*``, via ``SS2Dv2.get_outnorm``), ``_mul``, ``_act``, then,
        only when ``d_conv > 1``, ``_ca``, ``_ca1``, ``_ocov2``, ``_ocov`` and ``_cpos``.
        Extra ``**kwargs`` (e.g. ``act_layer``) are accepted and ignored.
        """
        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.channel_first = channel_first
        self.d_state = d_state
        self.dt_rank = dt_rank
        self.d_inner = d_inner
        k_group = 4
        self.with_dconv = d_conv > 1
        Linear = Linear2d if channel_first else nn.Linear

        # tags for forward_type ==============================
        checkpostfix = SS2Dv2.checkpostfix
        self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, channel_first)
        self.omul, forward_type = checkpostfix("_mul", forward_type)
        self.oact, forward_type = checkpostfix("_act", forward_type)
        # Kept for attribute compatibility; forwardxv does not use it.
        self.f_omul = nn.Identity() if self.omul else None
        self.out_act = nn.GELU() if self.oact else nn.Identity()

        mode = forward_type[:4]
        assert mode in ["xv1a", "xv2a", "xv3a"]
        self.forward = partial(self.forwardxv, mode=mode)
        # Width of the dt slice produced by in_proj:
        #   xv1a: dt_rank channels, cross-scanned like ``us`` then projected by dt_projs_weight;
        #   xv2a: d_inner channels used directly as dt (dt_projs_weight is dropped below);
        #   xv3a: 4 * dt_rank channels, one dt_rank group per scan direction, then projected.
        self.dts_dim = dict(xv1a=self.dt_rank, xv2a=self.d_inner, xv3a=4 * self.dt_rank)[mode]
        # in_proj emits [us | dts | Bs (4 * d_state) | Cs (4 * d_state)].
        d_inner_all = d_inner + self.dts_dim + 8 * d_state
        self.in_proj = Linear(d_model, d_inner_all, bias=bias)

        # conv =======================================
        self.cpos = False
        self.iconv = False
        self.oconv = False
        self.oconv2 = False
        if self.with_dconv:
            cact, forward_type = checkpostfix("_ca", forward_type)
            cact1, forward_type = checkpostfix("_ca1", forward_type)
            self.cact = nn.SiLU() if cact else nn.Identity()
            self.cact = nn.GELU() if cact1 else self.cact

            self.oconv2, forward_type = checkpostfix("_ocov2", forward_type)
            self.oconv, forward_type = checkpostfix("_ocov", forward_type)
            self.cpos, forward_type = checkpostfix("_cpos", forward_type)
            # Input-side conv is used unless an output-side conv was requested.
            self.iconv = (not self.oconv) and (not self.oconv2)
            if self.iconv:
                # Depth-wise conv on the block input (d_model channels).
                self.conv2d = nn.Conv2d(
                    in_channels=d_model,
                    out_channels=d_model,
                    groups=d_model,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                )
            if self.oconv:
                # Depth-wise conv on ``us``, added to the output in forwardxv.
                self.oconv2d = nn.Conv2d(
                    in_channels=d_inner,
                    out_channels=d_inner,
                    groups=d_inner,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                )
            if self.oconv2:
                # Depth-wise conv over the whole in_proj output.
                self.conv2d = nn.Conv2d(
                    in_channels=d_inner_all,
                    out_channels=d_inner_all,
                    groups=d_inner_all,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                )

        # out proj =======================================
        self.out_proj = Linear(d_inner, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

        if initialize in ["v0"]:
            self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
                d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
            )
        elif initialize in ["v1"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
            self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
            self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
        elif initialize in ["v2"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
            self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank)))
            self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner)))

        if forward_type.startswith("xv2"):
            # xv2a predicts dt at full width, so no dt projection is needed.
            del self.dt_projs_weight
            self.dt_projs_weight = None

    def forwardxv(self, x: torch.Tensor, **kwargs):
        """Forward pass for the ``xv*`` modes; ``**kwargs`` (incl. ``mode``) are ignored.

        ``x`` is (b, d, h, w) when ``self.channel_first`` else (b, h, w, d).
        """
        B, (H, W) = x.shape[0], (x.shape[2:4] if self.channel_first else x.shape[1:3])
        L = H * W
        force_fp32 = False
        delta_softplus = True
        out_norm = self.out_norm
        to_dtype = True
        to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

        def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):
            return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex=True, backend=None)

        # NOTE(review): the ``cpos`` branch is only reachable when iconv is False,
        # i.e. with _ocov (no ``conv2d`` exists) or _ocov2 (``conv2d`` expects
        # d_inner_all channels, not d_model) -- confirm the intended combination.
        if self.iconv:
            x = self.cact(self.conv2d(x))  # (b, d, h, w)
        elif self.cpos:
            x = x + self.conv2d(x)  # (b, d, h, w)
        x = self.in_proj(x)

        if self.oconv2:
            x = self.conv2d(x)  # (b, d, h, w)
        us, dts, Bs, Cs = x.split(
            [self.d_inner, self.dts_dim, 4 * self.d_state, 4 * self.d_state],
            dim=(1 if self.channel_first else -1),
        )
        _us = us
        us = cross_scan_fn(us.contiguous(), in_channel_first=self.channel_first, out_channel_first=True).view(B, -1, L)
        Bs = cross_scan_fn(Bs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
        Cs = cross_scan_fn(Cs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
        # Same channels-before-length layout as ``us`` (B, 4 * C, L); this is also the
        # (N, C_in, L) layout F.conv1d expects. The previous ``.view(B, L, -1)``
        # reinterpreted that buffer with the wrong axis order.
        dts = cross_scan_fn(
            dts.contiguous(), in_channel_first=self.channel_first, out_channel_first=True,
            one_by_one=(self.dts_dim == 4 * self.dt_rank),
        ).view(B, -1, L)
        if self.dt_projs_weight is not None:
            # xv1a / xv3a: per-direction projection dt_rank -> d_inner (grouped 1x1 conv).
            dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)

        As = -self.A_logs.to(torch.float).exp()  # (k * c, d_state)
        Ds = self.Ds.to(torch.float)  # (K * c)
        delta_bias = self.dt_projs_bias.view(-1).to(torch.float)  # (K * c)

        if force_fp32:
            us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs)

        ys: torch.Tensor = selective_scan(
            us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
        ).view(B, 4, -1, H, W)
        y: torch.Tensor = cross_merge_fn(ys.contiguous(), in_channel_first=self.channel_first, out_channel_first=True)
        y = y.view(B, -1, H, W) if self.channel_first else y.view(B, H, W, -1)
        y = out_norm(y)

        if getattr(self, "__DEBUG__", False):
            setattr(self, "__data__", dict(
                A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
                us=us, dts=dts, delta_bias=delta_bias,
                ys=ys, y=y,
            ))

        y = (y.to(x.dtype) if to_dtype else y)

        y = self.out_act(y)

        if self.omul:
            y = y * _us
        if self.oconv:
            y = y + self.cact(self.oconv2d(_us))

        out = self.dropout(self.out_proj(y))
        return out
- # mamba2 support ================================
class SS2Dm0:
    """Mamba2-style SS2D variant ("m*" forward types), used as a mixin of ``SS2D``.

    ``SS2D.__init__`` calls ``__initm0__`` when ``forward_type`` starts with ``"m"``.
    This path is channel-last only: the core reads ``x`` as (b, h, w, c).
    """

    def __initm0__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,  # now with mamba2, dstate should be bigger...
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.GELU,
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v2",
        # ======================
        forward_type="m0",
        # ======================
        with_initial_state=False,
        # ======================
        **kwargs,
    ):
        """Build the layers for the ``m0`` forward type.

        Postfixes consumed from ``forward_type``: ``_no32``, ``_oact``, ``_noz``,
        ``_nozact``, then the ``_on*`` out-norm suffixes.

        NOTE(review): only ``initialize`` "v1"/"v2" create ``A_logs``/``Ds``/
        ``dt_projs_bias``; any other value (e.g. SS2D's default "v0") leaves them
        undefined and ``forward_corem0`` will fail.
        """
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_inner = int(ssm_ratio * d_model)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        assert d_inner % dt_rank == 0
        # forwardm0 reads ``self.channel_first``, but nothing else sets it on this path.
        # Everything built here assumes channel-last input: nn.Linear projections,
        # the Permute-wrapped conv, and get_outnorm(..., False). So pin it to False.
        # A ``channel_first`` passed through **kwargs is ignored.
        self.channel_first = False
        self.with_dconv = d_conv > 1
        Linear = nn.Linear
        self.forward = self.forwardm0

        # tags for forward_type ==============================
        checkpostfix = SS2Dv2.checkpostfix
        self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
        self.oact, forward_type = checkpostfix("_oact", forward_type)
        self.disable_z, forward_type = checkpostfix("_noz", forward_type)
        self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
        self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False)

        # forward_type debug =======================================
        FORWARD_TYPES = dict(
            m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state),
        )
        self.forward_core = FORWARD_TYPES.get(forward_type, None)
        k_group = 4

        # in proj =======================================
        d_proj = d_inner if self.disable_z else (d_inner * 2)
        self.in_proj = Linear(d_model, d_proj, bias=bias)
        self.act: nn.Module = act_layer()

        # conv =======================================
        if self.with_dconv:
            # Depth-wise conv wrapped in permutes: takes and returns (b, h, w, c).
            self.conv2d = nn.Sequential(
                Permute(0, 3, 1, 2),
                nn.Conv2d(
                    in_channels=d_inner,
                    out_channels=d_inner,
                    groups=d_inner,
                    bias=conv_bias,
                    kernel_size=d_conv,
                    padding=(d_conv - 1) // 2,
                    **factory_kwargs,
                ),
                Permute(0, 2, 3, 1),
            )

        # x proj ============================
        # One (dt, B, C) projection per scan direction, stacked into a single parameter.
        self.x_proj = [
            nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
            for _ in range(k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj

        # out proj =======================================
        self.out_act = nn.GELU() if self.oact else nn.Identity()
        self.out_proj = Linear(d_inner, d_model, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        if initialize in ["v1"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
            self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank)))  # 0.1 is added in 0430
        elif initialize in ["v2"]:
            # simple init dt_projs, A_logs, Ds
            self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
            self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
            self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank)))

        # init state ============================
        self.initial_state = None
        if with_initial_state:
            self.initial_state = nn.Parameter(torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), requires_grad=False)

    def forward_corem0(
        self,
        x: torch.Tensor = None,
        # ==============================
        force_fp32=False,  # True: input fp32
        chunk_size=64,
        dstate=64,
        # ==============================
        selective_scan_backend=None,
        scan_mode="cross2d",
        scan_force_torch=False,
        # ==============================
        **kwargs,
    ):
        """Chunked selective scan over the 2D cross-scan of ``x`` (b, h, w, R*D).

        Returns the merged, out-normed result in ``x``'s dtype. When
        ``self.initial_state`` is set, it is replaced by the batch-summed final
        state (detached) after every call.
        """
        assert scan_mode in ["unidi", "bidi", "cross2d"]
        assert selective_scan_backend in [None, "triton", "torch"]
        x_proj_bias = getattr(self, "x_proj_bias", None)
        to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

        N = dstate
        B, H, W, RD = x.shape
        K, R = self.A_logs.shape
        K, R, D = self.Ds.shape
        assert RD == R * D
        L = H * W
        KR = K * R
        _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode]

        initial_state = None
        if self.initial_state is not None:
            assert self.initial_state.shape[-1] == dstate
            initial_state = self.initial_state.detach().repeat(B, 1, 1, 1)

        xs = cross_scan_fn(x.view(B, H, W, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)  # (B, H, W, 4, D)
        x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight)
        if x_proj_bias is not None:
            x_dbl = x_dbl + x_proj_bias.view(1, -1, K, 1)
        dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3)
        xs = xs.contiguous().view(B, L, KR, D)
        dts = dts.contiguous().view(B, L, KR)
        Bs = Bs.contiguous().view(B, L, K, N)
        Cs = Cs.contiguous().view(B, L, K, N)

        As = -self.A_logs.to(torch.float).exp().view(KR)
        Ds = self.Ds.to(torch.float).view(KR, D)
        dt_bias = self.dt_projs_bias.view(KR)

        # A single cast is enough; the original code repeated this block verbatim.
        if force_fp32:
            xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)

        ys, final_state = selective_scan_chunk_fn(
            xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias,
            initial_states=initial_state, dt_softplus=True, return_final_states=True,
            backend=selective_scan_backend,
        )
        y: torch.Tensor = cross_merge_fn(ys.view(B, H, W, K, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)

        if getattr(self, "__DEBUG__", False):
            setattr(self, "__data__", dict(
                A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds,
                us=xs, dts=dts, delta_bias=self.dt_projs_bias,
                initial_state=self.initial_state, final_satte=final_state,
                ys=ys, y=y, H=H, W=W,
            ))
        if self.initial_state is not None:
            self.initial_state = nn.Parameter(final_state.detach().sum(0, keepdim=True), requires_grad=False)

        y = self.out_norm(y.view(B, H, W, -1))

        return y.to(x.dtype)

    def forwardm0(self, x: torch.Tensor, **kwargs):
        """Block forward: in-proj -> optional gate split -> conv/act -> core -> gate -> out-proj."""
        x = self.in_proj(x)
        if not self.disable_z:
            x, z = x.chunk(2, dim=(1 if self.channel_first else -1))  # (b, h, w, d)
            if not self.disable_z_act:
                z = self.act(z)
        if self.with_dconv:
            x = self.conv2d(x)  # (b, h, w, d): conv2d permutes internally
        x = self.act(x)
        y = self.forward_core(x)
        y = self.out_act(y)
        if not self.disable_z:
            y = y * z
        out = self.dropout(self.out_proj(y))
        return out
class SS2D(nn.Module, SS2Dv0, SS2Dv2, SS2Dv3, SS2Dm0):
    """2D selective-scan token mixer that picks its implementation from ``forward_type``.

    Routing:
        "v0" / "v0seq"      -> ``__initv0__`` (``seq=True`` for "v0seq")
        starts with "xv"    -> ``__initxv__``
        starts with "m"     -> ``__initm0__``
        anything else       -> ``__initv2__``
    All named arguments plus any extra ``**kwargs`` are forwarded as keywords.
    """

    def __init__(
        self,
        # basic dims ===========
        d_model=96,
        d_state=16,
        ssm_ratio=2.0,
        dt_rank="auto",
        act_layer=nn.SiLU,
        # dwconv ===============
        d_conv=3,  # < 2 means no conv
        conv_bias=True,
        # ======================
        dropout=0.0,
        bias=False,
        # dt init ==============
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        initialize="v0",
        # ======================
        forward_type="v2",
        channel_first=False,
        # ======================
        **kwargs,
    ):
        nn.Module.__init__(self)
        options = dict(
            kwargs,
            d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank,
            act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias,
            dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor,
            initialize=initialize, forward_type=forward_type, channel_first=channel_first,
        )

        if forward_type in ("v0", "v0seq"):
            self.__initv0__(seq=("seq" in forward_type), **options)
            return
        if forward_type.startswith("xv"):
            self.__initxv__(**options)
            return
        if forward_type.startswith("m"):
            self.__initm0__(**options)
            return
        self.__initv2__(**options)
- # =====================================================
class VSSBlock(nn.Module):
    """Residual VMamba block: an SS2D token mixer followed by an optional MLP.

    Each branch is active only when its ratio is positive (``ssm_ratio`` /
    ``mlp_ratio``). Each branch is added back as ``x + drop_path(branch(x))``.
    ``post_norm`` applies the norm after the branch instead of before it.
    ``use_checkpoint`` routes the forward through ``checkpoint.checkpoint``.
    Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: nn.Module = nn.LayerNorm,
        channel_first=False,
        # =============================
        ssm_d_state: int = 16,
        ssm_ratio=2.0,
        ssm_dt_rank: Any = "auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv: int = 3,
        ssm_conv_bias=True,
        ssm_drop_rate: float = 0,
        ssm_init="v0",
        forward_type="v2",
        # =============================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate: float = 0.0,
        gmlp=False,
        # =============================
        use_checkpoint: bool = False,
        post_norm: bool = False,
        # =============================
        _SS2D: type = SS2D,
        **kwargs,
    ):
        super().__init__()
        self.ssm_branch = ssm_ratio > 0
        self.mlp_branch = mlp_ratio > 0
        self.use_checkpoint = use_checkpoint
        self.post_norm = post_norm

        if self.ssm_branch:
            self.norm = norm_layer(hidden_dim)
            # bias and the dt_* options are left at the SS2D defaults.
            ssm_options = dict(
                d_model=hidden_dim,
                d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                dt_rank=ssm_dt_rank,
                act_layer=ssm_act_layer,
                d_conv=ssm_conv,
                conv_bias=ssm_conv_bias,
                dropout=ssm_drop_rate,
                initialize=ssm_init,
                forward_type=forward_type,
                channel_first=channel_first,
            )
            self.op = _SS2D(**ssm_options)

        self.drop_path = DropPath(drop_path)

        if self.mlp_branch:
            mlp_cls = gMlp if gmlp else Mlp
            self.norm2 = norm_layer(hidden_dim)
            self.mlp = mlp_cls(
                in_features=hidden_dim,
                hidden_features=int(hidden_dim * mlp_ratio),
                act_layer=mlp_act_layer,
                drop=mlp_drop_rate,
                channels_first=channel_first,
            )

    def _branch(self, norm, fn, x):
        # Post-norm: norm(fn(x)); pre-norm: fn(norm(x)).
        return norm(fn(x)) if self.post_norm else fn(norm(x))

    def _forward(self, input: torch.Tensor):
        out = input
        if self.ssm_branch:
            out = out + self.drop_path(self._branch(self.norm, self.op, out))
        if self.mlp_branch:
            out = out + self.drop_path(self._branch(self.norm2, self.mlp, out))  # FFN
        return out

    def forward(self, input: torch.Tensor):
        if not self.use_checkpoint:
            return self._forward(input)
        return checkpoint.checkpoint(self._forward, input)
class VSSM(nn.Module):
    """VMamba image classifier.

    Pipeline (see ``forward``):
        1. Patch embedding.
        2. Optional learned positional embedding.
        3. ``len(depths)`` stages, each a stack of VSSBlocks followed by a downsample
           (``nn.Identity`` after the last stage).
        4. Classifier: norm, global average pool, flatten, linear head.

    ``norm_layer``, ``ssm_act_layer`` and ``mlp_act_layer`` are given by name and
    resolved through small lookup tables. ``norm_layer`` "bn" or "ln2d" switches
    the network to channel-first (b, c, h, w) tensors.
    """

    def __init__(
        self,
        patch_size=4,
        in_chans=3,
        num_classes=1000,
        depths=[2, 2, 9, 2],
        dims=[96, 192, 384, 768],
        # =========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer="silu",
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v2",
        # =========================
        mlp_ratio=4.0,
        mlp_act_layer="gelu",
        mlp_drop_rate=0.0,
        gmlp=False,
        # =========================
        drop_path_rate=0.1,
        patch_norm=True,
        norm_layer="LN",  # "BN", "LN2D"
        downsample_version: str = "v2",  # "v1", "v2", "v3"
        patchembed_version: str = "v1",  # "v1", "v2"
        use_checkpoint=False,
        # =========================
        posembed=False,
        imgsize=224,
        _SS2D=SS2D,
        # =========================
        **kwargs,
    ):
        super().__init__()
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            # An int gives the stage-0 width; every later stage doubles it.
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.num_features = dims[-1]
        self.dims = dims
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )

        _ACTLAYERS = dict(
            silu=nn.SiLU,
            gelu=nn.GELU,
            relu=nn.ReLU,
            sigmoid=nn.Sigmoid,
        )

        norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)
        ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
        mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)

        self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None

        _make_patch_embed = dict(
            v1=self._make_patch_embed,
            v2=self._make_patch_embed_v2,
        ).get(patchembed_version, None)
        self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first)

        _make_downsample = dict(
            v1=PatchMerging2D,
            v2=self._make_downsample,
            v3=self._make_downsample_v3,
            none=(lambda *_, **_k: None),
        ).get(downsample_version, None)

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            downsample = _make_downsample(
                self.dims[i_layer],
                self.dims[i_layer + 1],
                norm_layer=norm_layer,
                channel_first=self.channel_first,
            ) if (i_layer < self.num_layers - 1) else nn.Identity()

            self.layers.append(self._make_layer(
                dim=self.dims[i_layer],
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                use_checkpoint=use_checkpoint,
                norm_layer=norm_layer,
                downsample=downsample,
                channel_first=self.channel_first,
                # =================
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                # =================
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                gmlp=gmlp,
                # =================
                _SS2D=_SS2D,
            ))

        self.classifier = nn.Sequential(OrderedDict(
            norm=norm_layer(self.num_features),  # B,H,W,C
            permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
            avgpool=nn.AdaptiveAvgPool2d(1),
            flatten=nn.Flatten(1),
            head=nn.Linear(self.num_features, num_classes),
        ))

        self.apply(self._init_weights)

    @staticmethod
    def _pos_embed(embed_dims, patch_size, img_size):
        """Create a (1, embed_dims, img_size // patch_size, img_size // patch_size) trunc-normal parameter."""
        patch_height, patch_width = (img_size // patch_size, img_size // patch_size)
        pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width))
        trunc_normal_(pos_embed, std=0.02)
        return pos_embed

    def _init_weights(self, m: nn.Module):
        """Linear: trunc-normal(0.02) weight and zero bias. LayerNorm: weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # used in building optimizer
    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed"}

    # used in building optimizer
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {}

    @staticmethod
    def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()),
        )

    @staticmethod
    def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        # Two strided convs (each stride patch_size // 2) instead of one patch_size conv.
        stride = patch_size // 2
        kernel_size = stride + 1
        padding = 1
        return nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
            (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
            (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            (norm_layer(embed_dim) if patch_norm else nn.Identity()),
        )

    @staticmethod
    def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
            nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            norm_layer(out_dim),
        )

    @staticmethod
    def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
        # if channel first, then Norm and Output are both channel_first
        return nn.Sequential(
            (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
            nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
            (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            norm_layer(out_dim),
        )

    @staticmethod
    def _make_layer(
        dim=96,
        drop_path=(0.1, 0.1),
        use_checkpoint=False,
        norm_layer=nn.LayerNorm,
        downsample=None,
        channel_first=False,
        # ===========================
        ssm_d_state=16,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer=nn.SiLU,
        ssm_conv=3,
        ssm_conv_bias=True,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v2",
        # ===========================
        mlp_ratio=4.0,
        mlp_act_layer=nn.GELU,
        mlp_drop_rate=0.0,
        gmlp=False,
        # ===========================
        _SS2D=SS2D,
        **kwargs,
    ):
        """Build one stage: ``len(drop_path)`` VSSBlocks followed by ``downsample``.

        ``downsample=None`` means ``nn.Identity()``. A fresh module is created per
        call instead of one instance shared through the default argument.
        """
        # if channel first, then Norm and Output are both channel_first
        if downsample is None:
            downsample = nn.Identity()
        depth = len(drop_path)
        blocks = []
        for d in range(depth):
            blocks.append(VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[d],
                norm_layer=norm_layer,
                channel_first=channel_first,
                ssm_d_state=ssm_d_state,
                ssm_ratio=ssm_ratio,
                ssm_dt_rank=ssm_dt_rank,
                ssm_act_layer=ssm_act_layer,
                ssm_conv=ssm_conv,
                ssm_conv_bias=ssm_conv_bias,
                ssm_drop_rate=ssm_drop_rate,
                ssm_init=ssm_init,
                forward_type=forward_type,
                mlp_ratio=mlp_ratio,
                mlp_act_layer=mlp_act_layer,
                mlp_drop_rate=mlp_drop_rate,
                gmlp=gmlp,
                use_checkpoint=use_checkpoint,
                _SS2D=_SS2D,
            ))

        return nn.Sequential(OrderedDict(
            blocks=nn.Sequential(*blocks,),
            downsample=downsample,
        ))

    def forward(self, x: torch.Tensor):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            # pos_embed is stored channel-first; match the activation layout.
            pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed
            x = x + pos_embed
        for layer in self.layers:
            x = layer(x)
        x = self.classifier(x)
        return x

    def flops(self, shape=(3, 224, 224), verbose=True):
        """Count FLOPs for one forward pass on a random ``(1, *shape)`` input using fvcore.

        A deep copy of the model is moved to CUDA (a GPU is required) and set to eval
        mode. The per-op counts from ``flop_count`` are summed and scaled by 1e9
        (they are treated as GFLOPs). The total is returned as a float.
        """
        from fvcore.nn import flop_count
        # shape = self.__input_shape__[1:]
        supported_ops = {
            "aten::silu": None,  # as relu is in _IGNORED_OPS
            "aten::neg": None,  # as relu is in _IGNORED_OPS
            "aten::exp": None,  # as relu is in _IGNORED_OPS
            "aten::flip": None,  # as permute is in _IGNORED_OPS
            "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose),
        }

        model = copy.deepcopy(self)
        model.cuda().eval()

        input = torch.randn((1, *shape), device=next(model.parameters()).device)
        Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)

        del model, input
        return sum(Gflops.values()) * 1e9

    # used to load ckpt from previous training code
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Rename legacy checkpoint keys in place, then defer to ``nn.Module``.

        * ``pos_embed`` is bicubically resized to this model's grid, but only when
          this model has a positional embedding. Otherwise the key is left for the
          normal missing/unexpected handling.
        * ``patch_embed.proj``/``.norm`` -> ``patch_embed.0``/``.2``.
        * ``layers.{i}.blocks.{j}.ln_1`` -> ``.norm``; ``.self_attention`` -> ``.op``.
        * top-level ``norm``/``head`` -> ``classifier.norm``/``classifier.head``.
        """
        import re

        def check_name(src, state_dict: dict = state_dict, strict=False):
            if strict:
                if prefix + src in list(state_dict.keys()):
                    return True
            else:
                key = prefix + src
                for k in list(state_dict.keys()):
                    if k.startswith(key):
                        return True
            return False

        def change_name(src, dst, state_dict: dict = state_dict, strict=False):
            if strict:
                if prefix + src in list(state_dict.keys()):
                    state_dict[prefix + dst] = state_dict[prefix + src]
                    state_dict.pop(prefix + src)
            else:
                key = prefix + src
                for k in list(state_dict.keys()):
                    if k.startswith(key):
                        new_k = prefix + dst + k[len(key):]
                        state_dict[new_k] = state_dict[k]
                        state_dict.pop(k)

        # Without the pos_embed guard, a checkpoint that has pos_embed loaded into a
        # model built with posembed=False crashed on ``None.shape``.
        if self.pos_embed is not None and check_name("pos_embed", strict=True):
            srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"]
            state_dict[prefix + "pos_embed"] = F.interpolate(srcEmb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic").to(srcEmb.device)

        change_name("patch_embed.proj", "patch_embed.0")
        change_name("patch_embed.norm", "patch_embed.2")

        # Legacy per-block names, renamed in one pass over the keys. This replaces
        # 100x100 prefix scans that also skipped stage/block indices >= 100.
        block_renames = {"ln_1": "norm", "self_attention": "op"}
        block_key = re.compile(re.escape(prefix) + r"(layers\.\d+\.blocks\.\d+\.)(ln_1|self_attention)")
        for k in list(state_dict.keys()):
            match = block_key.match(k)
            if match is not None:
                new_k = prefix + match.group(1) + block_renames[match.group(2)] + k[match.end():]
                state_dict[new_k] = state_dict.pop(k)

        change_name("norm", "classifier.norm")
        change_name("head", "classifier.head")

        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
- # compatible with openmmlab
class Backbone_VSSM(VSSM):
    """VSSM feature extractor for dense prediction (OpenMMLab-compatible).

    Drops the classifier and adds an ``outnorm{i}`` norm for each stage index in
    ``out_indices``. ``forward`` returns one normalised feature map per index,
    permuted to (b, c, h, w) when the network runs channel-last.
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)

        self.out_indices = out_indices
        for i in out_indices:
            layer = norm_layer(self.dims[i])
            layer_name = f'outnorm{i}'
            self.add_module(layer_name, layer)

        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Best-effort, non-strict load of ``torch.load(ckpt)[key]``; failures are printed, not raised.

        NOTE(review): ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        if ckpt is None:
            return

        try:
            # Use a context manager so the file handle is closed (it previously leaked).
            with open(ckpt, "rb") as f:
                _ckpt = torch.load(f, map_location=torch.device("cpu"))
            print(f"Successfully load ckpt {ckpt}")
            incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatibleKeys)
        except Exception as e:
            print(f"Failed loading checkpoint form {ckpt}: {e}")

    def forward(self, x):
        """Return per-stage features for ``out_indices``, or the last stage output if that is empty."""
        def layer_forward(l, x):
            # Keep the pre-downsample features ``x`` as the stage output.
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)  # (B, H, W, C)
            if i in self.out_indices:
                norm_layer = getattr(self, f'outnorm{i}')
                out = norm_layer(o)
                if not self.channel_first:
                    out = out.permute(0, 3, 1, 2)
                outs.append(out.contiguous())

        if len(self.out_indices) == 0:
            return x

        return outs
- # =====================================================
def vanilla_vmamba_tiny():
    """Original VMamba-T: depths 2-2-9-2, width 96, v0 SS2D, no MLP branch (mlp_ratio=0)."""
    config = dict(
        depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v0",
        mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer="ln",
        downsample_version="v1", patchembed_version="v1",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vanilla_vmamba_small():
    """Original VMamba-S: depths 2-2-27-2, width 96, v0 SS2D, no MLP branch (mlp_ratio=0)."""
    config = dict(
        depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v0",
        mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer="ln",
        downsample_version="v1", patchembed_version="v1",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vanilla_vmamba_base():
    """Original VMamba-B: depths 2-2-27-2, width 128, v0 SS2D, no MLP branch (mlp_ratio=0)."""
    config = dict(
        depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v0",
        mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer="ln",
        downsample_version="v1", patchembed_version="v1",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
- # =====================================================
def vmamba_tiny_s2l5(channel_first=True):
    """VMamba-T (ssm_ratio 2, depths 2-2-5-2, "v05_noz"); ``channel_first`` picks "ln2d" over "ln"."""
    norm_name = "ln2d" if channel_first else "ln"
    config = dict(
        depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer=norm_name,
        downsample_version="v3", patchembed_version="v2",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
def vmamba_small_s2l15(channel_first=True):
    """VMamba-S (ssm_ratio 2, depths 2-2-15-2, "v05_noz"); ``channel_first`` picks "ln2d" over "ln"."""
    norm_name = "ln2d" if channel_first else "ln"
    config = dict(
        depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v0", forward_type="v05_noz",
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer=norm_name,
        downsample_version="v3", patchembed_version="v2",
        use_checkpoint=False, posembed=False, imgsize=224,
    )
    return VSSM(**config)
- def vmamba_base_s2l15(channel_first=True):
- return VSSM(
- depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6,
- patch_size=4, in_chans=3, num_classes=1000,
- ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
- ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
- ssm_init="v0", forward_type="v05_noz",
- mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
- patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
- downsample_version="v3", patchembed_version="v2",
- use_checkpoint=False, posembed=False, imgsize=224,
- )
- # =====================================================
- def vmamba_tiny_s1l8(channel_first=True):
- return VSSM(
- depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2,
- patch_size=4, in_chans=3, num_classes=1000,
- ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
- ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
- ssm_init="v0", forward_type="v05_noz",
- mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
- patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
- downsample_version="v3", patchembed_version="v2",
- use_checkpoint=False, posembed=False, imgsize=224,
- )
- def vmamba_small_s1l20(channel_first=True):
- return VSSM(
- depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3,
- patch_size=4, in_chans=3, num_classes=1000,
- ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
- ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
- ssm_init="v0", forward_type="v05_noz",
- mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
- patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
- downsample_version="v3", patchembed_version="v2",
- use_checkpoint=False, posembed=False, imgsize=224,
- )
- def vmamba_base_s1l20(channel_first=True):
- return VSSM(
- depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5,
- patch_size=4, in_chans=3, num_classes=1000,
- ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
- ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
- ssm_init="v0", forward_type="v05_noz",
- mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
- patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
- downsample_version="v3", patchembed_version="v2",
- use_checkpoint=False, posembed=False, imgsize=224,
- )
# mamba2 support =====================================================
# NOTE: FLOPs counting (VSSM.flops) does not work for mamba2 models yet.
def _vmamba_m2(depths, dims, drop_path_rate):
    """Build a mamba2-style (``forward_type="m0_noz"``) VSSM.

    The public ``vmamba_*_m2`` factories differ only in ``depths``, ``dims``
    and ``drop_path_rate``. All other hyper-parameters are fixed here so the
    three variants stay consistent.

    Args:
        depths: stage depths passed to VSSM, e.g. ``[2, 2, 4, 2]``.
        dims: value passed to VSSM as ``dims``.
        drop_path_rate: drop-path rate passed to VSSM.

    Returns:
        A new ``VSSM`` instance.
    """
    return VSSM(
        depths=depths, dims=dims, drop_path_rate=drop_path_rate,
        patch_size=4, in_chans=3, num_classes=1000,
        ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
        ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
        ssm_init="v2", forward_type="m0_noz",
        mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
        patch_norm=True, norm_layer="ln",
        downsample_version="v3", patchembed_version="v2",
        use_checkpoint=False, posembed=False, imgsize=224,
    )


def vmamba_tiny_m2():
    """Mamba2-style VMamba-Tiny: depths [2, 2, 4, 2], dims 96, drop-path 0.2."""
    return _vmamba_m2(depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2)


def vmamba_small_m2():
    """Mamba2-style VMamba-Small: depths [2, 2, 12, 2], dims 96, drop-path 0.3."""
    return _vmamba_m2(depths=[2, 2, 12, 2], dims=96, drop_path_rate=0.3)


def vmamba_base_m2():
    """Mamba2-style VMamba-Base: depths [2, 2, 12, 2], dims 128, drop-path 0.3."""
    return _vmamba_m2(depths=[2, 2, 12, 2], dims=128, drop_path_rate=0.3)
if __name__ == "__main__":
    # Ad-hoc developer benchmark (needs fvcore and a CUDA device). It compares
    # the mamba2-style "m0_noz" tiny model against the vmamba_tiny_s1l8
    # reference: parameter count, reported FLOPs and train-mode timings.
    import time

    from fvcore.nn import parameter_count

    def bench(model, batch_size=128, warmup=30, iters=30, imgsize=224):
        """Time ``model`` on random input and return seconds per iteration.

        Returns a ``(forward, forward_plus_backward)`` tuple. Each phase runs
        ``warmup`` untimed iterations, then ``iters`` timed ones. The forward
        phase runs with autograd enabled, so it includes graph-recording
        overhead, as in training.
        """
        inp = torch.randn((batch_size, 3, imgsize, imgsize)).cuda()

        def timed(step):
            for _ in range(warmup):
                step()
            # CUDA kernels launch asynchronously, so drain the queue before
            # starting and before stopping the clock.
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(iters):
                step()
            torch.cuda.synchronize()
            return (time.perf_counter() - start) / iters

        forward = timed(lambda: model(inp))
        forward_backward = timed(lambda: model(inp).sum().backward())
        return forward, forward_backward

    model_ref = vmamba_tiny_s1l8()
    # Same configuration as vmamba_tiny_m2(); reuse the factory instead of
    # repeating its hyper-parameters here.
    model = vmamba_tiny_m2()
    print(parameter_count(model)[""])
    # NOTE: flops() is known to be incorrect for mamba2 ("m0") forward types,
    # so treat this number as unreliable.
    print(model.flops())
    model.cuda().train()
    model_ref.cuda().train()

    print(bench(model_ref))
    print(bench(model))
|