vmamba.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846
  1. import os
  2. import time
  3. import math
  4. import copy
  5. from functools import partial
  6. from typing import Optional, Callable, Any
  7. from collections import OrderedDict
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import torch.utils.checkpoint as checkpoint
  12. from timm.layers import DropPath, trunc_normal_
  13. from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
  14. try:
  15. from .csm_triton import cross_scan_fn, cross_merge_fn
  16. except:
  17. from csm_triton import cross_scan_fn, cross_merge_fn
  18. try:
  19. from .csms6s import selective_scan_fn, selective_scan_flop_jit
  20. except:
  21. from csms6s import selective_scan_fn, selective_scan_flop_jit
  22. # FLOPs counter not prepared for mamba2.
  23. # Keep this dependency optional because SS2D(v2/v3) does not require it.
  24. try:
  25. from .mamba2.ssd_minimal import selective_scan_chunk_fn
  26. except Exception:
  27. try:
  28. from mamba2.ssd_minimal import selective_scan_chunk_fn
  29. except Exception:
  30. selective_scan_chunk_fn = None
  31. # =====================================================
  32. # we have this class as linear and conv init differ from each other
  33. # this function enable loading from both conv2d or linear
  34. class Linear2d(nn.Linear):
  35. def forward(self, x: torch.Tensor):
  36. # B, C, H, W = x.shape
  37. return F.conv2d(x, self.weight[:, :, None, None], self.bias)
  38. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  39. state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view(self.weight.shape)
  40. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  41. class LayerNorm2d(nn.LayerNorm):
  42. def forward(self, x: torch.Tensor):
  43. x = x.permute(0, 2, 3, 1)
  44. x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  45. x = x.permute(0, 3, 1, 2)
  46. return x
  47. class PatchMerging2D(nn.Module):
  48. def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False):
  49. super().__init__()
  50. self.dim = dim
  51. Linear = Linear2d if channel_first else nn.Linear
  52. self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last
  53. self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
  54. self.norm = norm_layer(4 * dim)
  55. @staticmethod
  56. def _patch_merging_pad_channel_last(x: torch.Tensor):
  57. H, W, _ = x.shape[-3:]
  58. if (W % 2 != 0) or (H % 2 != 0):
  59. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  60. x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
  61. x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
  62. x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
  63. x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
  64. x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
  65. return x
  66. @staticmethod
  67. def _patch_merging_pad_channel_first(x: torch.Tensor):
  68. H, W = x.shape[-2:]
  69. if (W % 2 != 0) or (H % 2 != 0):
  70. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  71. x0 = x[..., 0::2, 0::2] # ... H/2 W/2
  72. x1 = x[..., 1::2, 0::2] # ... H/2 W/2
  73. x2 = x[..., 0::2, 1::2] # ... H/2 W/2
  74. x3 = x[..., 1::2, 1::2] # ... H/2 W/2
  75. x = torch.cat([x0, x1, x2, x3], 1) # ... H/2 W/2 4*C
  76. return x
  77. def forward(self, x):
  78. x = self._patch_merging_pad(x)
  79. x = self.norm(x)
  80. x = self.reduction(x)
  81. return x
  82. class Permute(nn.Module):
  83. def __init__(self, *args):
  84. super().__init__()
  85. self.args = args
  86. def forward(self, x: torch.Tensor):
  87. return x.permute(*self.args)
  88. class Mlp(nn.Module):
  89. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
  90. super().__init__()
  91. out_features = out_features or in_features
  92. hidden_features = hidden_features or in_features
  93. Linear = Linear2d if channels_first else nn.Linear
  94. self.fc1 = Linear(in_features, hidden_features)
  95. self.act = act_layer()
  96. self.fc2 = Linear(hidden_features, out_features)
  97. self.drop = nn.Dropout(drop)
  98. def forward(self, x):
  99. x = self.fc1(x)
  100. x = self.act(x)
  101. x = self.drop(x)
  102. x = self.fc2(x)
  103. x = self.drop(x)
  104. return x
  105. class gMlp(nn.Module):
  106. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
  107. super().__init__()
  108. self.channel_first = channels_first
  109. out_features = out_features or in_features
  110. hidden_features = hidden_features or in_features
  111. Linear = Linear2d if channels_first else nn.Linear
  112. self.fc1 = Linear(in_features, 2 * hidden_features)
  113. self.act = act_layer()
  114. self.fc2 = Linear(hidden_features, out_features)
  115. self.drop = nn.Dropout(drop)
  116. def forward(self, x: torch.Tensor):
  117. x = self.fc1(x)
  118. x, z = x.chunk(2, dim=(1 if self.channel_first else -1))
  119. x = self.fc2(x * self.act(z))
  120. x = self.drop(x)
  121. return x
  122. class SoftmaxSpatial(nn.Softmax):
  123. def forward(self, x: torch.Tensor):
  124. if self.dim == -1:
  125. B, C, H, W = x.shape
  126. return super().forward(x.view(B, C, -1)).view(B, C, H, W)
  127. elif self.dim == 1:
  128. B, H, W, C = x.shape
  129. return super().forward(x.view(B, -1, C)).view(B, H, W, C)
  130. else:
  131. raise NotImplementedError
  132. # =====================================================
  133. class mamba_init:
  134. @staticmethod
  135. def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
  136. dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
  137. # Initialize special dt projection to preserve variance at initialization
  138. dt_init_std = dt_rank**-0.5 * dt_scale
  139. if dt_init == "constant":
  140. nn.init.constant_(dt_proj.weight, dt_init_std)
  141. elif dt_init == "random":
  142. nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
  143. else:
  144. raise NotImplementedError
  145. # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
  146. dt = torch.exp(
  147. torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
  148. + math.log(dt_min)
  149. ).clamp(min=dt_init_floor)
  150. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  151. inv_dt = dt + torch.log(-torch.expm1(-dt))
  152. with torch.no_grad():
  153. dt_proj.bias.copy_(inv_dt)
  154. # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
  155. # dt_proj.bias._no_reinit = True
  156. return dt_proj
  157. @staticmethod
  158. def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
  159. # S4D real initialization
  160. A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
  161. A_log = torch.log(A) # Keep A_log in fp32
  162. if copies > 0:
  163. A_log = A_log[None].repeat(copies, 1, 1).contiguous()
  164. if merge:
  165. A_log = A_log.flatten(0, 1)
  166. A_log = nn.Parameter(A_log)
  167. A_log._no_weight_decay = True
  168. return A_log
  169. @staticmethod
  170. def D_init(d_inner, copies=-1, device=None, merge=True):
  171. # D "skip" parameter
  172. D = torch.ones(d_inner, device=device)
  173. if copies > 0:
  174. D = D[None].repeat(copies, 1).contiguous()
  175. if merge:
  176. D = D.flatten(0, 1)
  177. D = nn.Parameter(D) # Keep in fp32
  178. D._no_weight_decay = True
  179. return D
  180. @classmethod
  181. def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
  182. # dt proj ============================
  183. dt_projs = [
  184. cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
  185. for _ in range(k_group)
  186. ]
  187. dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K, inner, rank)
  188. dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K, inner)
  189. del dt_projs
  190. # A, D =======================================
  191. A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
  192. Ds = cls.D_init(d_inner, copies=k_group, merge=True) # (K * D)
  193. return A_logs, Ds, dt_projs_weight, dt_projs_bias
  194. # support: v0, v0seq
  195. class SS2Dv0:
  196. def __initv0__(
  197. self,
  198. # basic dims ===========
  199. d_model=96,
  200. d_state=16,
  201. ssm_ratio=2.0,
  202. dt_rank="auto",
  203. # ======================
  204. dropout=0.0,
  205. # ======================
  206. seq=False,
  207. force_fp32=True,
  208. **kwargs,
  209. ):
  210. if "channel_first" in kwargs:
  211. assert not kwargs["channel_first"]
  212. act_layer = nn.SiLU
  213. dt_min = 0.001
  214. dt_max = 0.1
  215. dt_init = "random"
  216. dt_scale = 1.0
  217. dt_init_floor = 1e-4
  218. bias = False
  219. conv_bias = True
  220. d_conv = 3
  221. k_group = 4
  222. factory_kwargs = {"device": None, "dtype": None}
  223. super().__init__()
  224. d_inner = int(ssm_ratio * d_model)
  225. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  226. self.forward = self.forwardv0
  227. if seq:
  228. self.forward = partial(self.forwardv0, seq=True)
  229. if not force_fp32:
  230. self.forward = partial(self.forwardv0, force_fp32=False)
  231. # in proj ============================
  232. self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
  233. self.act: nn.Module = act_layer()
  234. self.conv2d = nn.Conv2d(
  235. in_channels=d_inner,
  236. out_channels=d_inner,
  237. groups=d_inner,
  238. bias=conv_bias,
  239. kernel_size=d_conv,
  240. padding=(d_conv - 1) // 2,
  241. **factory_kwargs,
  242. )
  243. # x proj ============================
  244. self.x_proj = [
  245. nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
  246. for _ in range(k_group)
  247. ]
  248. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  249. del self.x_proj
  250. # dt proj, A, D ============================
  251. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  252. d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
  253. )
  254. # out proj =======================================
  255. self.out_norm = nn.LayerNorm(d_inner)
  256. self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
  257. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  258. def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
  259. x = self.in_proj(x)
  260. x, z = x.chunk(2, dim=-1) # (b, h, w, d)
  261. z = self.act(z)
  262. x = x.permute(0, 3, 1, 2).contiguous()
  263. x = self.conv2d(x) # (b, d, h, w)
  264. x = self.act(x)
  265. selective_scan = partial(selective_scan_fn, backend="mamba")
  266. B, D, H, W = x.shape
  267. D, N = self.A_logs.shape
  268. K, D, R = self.dt_projs_weight.shape
  269. L = H * W
  270. x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
  271. xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
  272. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
  273. if hasattr(self, "x_proj_bias"):
  274. x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
  275. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  276. dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
  277. xs = xs.view(B, -1, L) # (b, k * d, l)
  278. dts = dts.contiguous().view(B, -1, L) # (b, k * d, l)
  279. Bs = Bs.contiguous() # (b, k, d_state, l)
  280. Cs = Cs.contiguous() # (b, k, d_state, l)
  281. As = -self.A_logs.float().exp() # (k * d, d_state)
  282. Ds = self.Ds.float() # (k * d)
  283. dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
  284. # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
  285. # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
  286. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  287. if force_fp32:
  288. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  289. if seq:
  290. out_y = []
  291. for i in range(4):
  292. yi = selective_scan(
  293. xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
  294. As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
  295. delta_bias=dt_projs_bias.view(K, -1)[i],
  296. delta_softplus=True,
  297. ).view(B, -1, L)
  298. out_y.append(yi)
  299. out_y = torch.stack(out_y, dim=1)
  300. else:
  301. out_y = selective_scan(
  302. xs, dts,
  303. As, Bs, Cs, Ds,
  304. delta_bias=dt_projs_bias,
  305. delta_softplus=True,
  306. ).view(B, K, -1, L)
  307. assert out_y.dtype == torch.float
  308. inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
  309. wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  310. invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  311. y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
  312. y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
  313. y = self.out_norm(y).view(B, H, W, -1)
  314. y = y * z
  315. out = self.dropout(self.out_proj(y))
  316. return out
  317. # support: v01-v05; v051d,v052d,v052dc;
  318. # postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32;
  319. # history support: v2,v3;v31d,v32d,v32dc;
  320. class SS2Dv2:
  321. def __initv2__(
  322. self,
  323. # basic dims ===========
  324. d_model=96,
  325. d_state=16,
  326. ssm_ratio=2.0,
  327. dt_rank="auto",
  328. act_layer=nn.SiLU,
  329. # dwconv ===============
  330. d_conv=3, # < 2 means no conv
  331. conv_bias=True,
  332. # ======================
  333. dropout=0.0,
  334. bias=False,
  335. # dt init ==============
  336. dt_min=0.001,
  337. dt_max=0.1,
  338. dt_init="random",
  339. dt_scale=1.0,
  340. dt_init_floor=1e-4,
  341. initialize="v0",
  342. # ======================
  343. forward_type="v2",
  344. channel_first=False,
  345. # ======================
  346. **kwargs,
  347. ):
  348. factory_kwargs = {"device": None, "dtype": None}
  349. super().__init__()
  350. self.k_group = 4
  351. self.d_model = int(d_model)
  352. self.d_state = int(d_state)
  353. self.d_inner = int(ssm_ratio * d_model)
  354. self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
  355. self.channel_first = channel_first
  356. self.with_dconv = d_conv > 1
  357. Linear = Linear2d if channel_first else nn.Linear
  358. self.forward = self.forwardv2
  359. # tags for forward_type ==============================
  360. checkpostfix = self.checkpostfix
  361. self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
  362. self.oact, forward_type = checkpostfix("_oact", forward_type)
  363. self.disable_z, forward_type = checkpostfix("_noz", forward_type)
  364. self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
  365. self.out_norm, forward_type = self.get_outnorm(forward_type, self.d_inner, channel_first)
  366. # forward_type debug =======================================
  367. FORWARD_TYPES = dict(
  368. v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", scan_force_torch=True),
  369. v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"),
  370. v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"),
  371. v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d"
  372. v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), # selective_scan_backend="oflex", scan_mode="cross2d"
  373. # ===============================
  374. v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"),
  375. v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"),
  376. v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"),
  377. v052d3=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode=3), # debug
  378. # ===============================
  379. v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"),
  380. v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"),
  381. )
  382. self.forward_core = FORWARD_TYPES.get(forward_type, None)
  383. # in proj =======================================
  384. d_proj = self.d_inner if self.disable_z else (self.d_inner * 2)
  385. self.in_proj = Linear(self.d_model, d_proj, bias=bias)
  386. self.act: nn.Module = act_layer()
  387. # conv =======================================
  388. if self.with_dconv:
  389. self.conv2d = nn.Conv2d(
  390. in_channels=self.d_inner,
  391. out_channels=self.d_inner,
  392. groups=self.d_inner,
  393. bias=conv_bias,
  394. kernel_size=d_conv,
  395. padding=(d_conv - 1) // 2,
  396. **factory_kwargs,
  397. )
  398. # x proj ============================
  399. self.x_proj = [
  400. nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False)
  401. for _ in range(self.k_group)
  402. ]
  403. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  404. del self.x_proj
  405. # out proj =======================================
  406. self.out_act = nn.GELU() if self.oact else nn.Identity()
  407. self.out_proj = Linear(self.d_inner, self.d_model, bias=bias)
  408. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  409. if initialize in ["v0"]:
  410. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  411. self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
  412. )
  413. elif initialize in ["v1"]:
  414. # simple init dt_projs, A_logs, Ds
  415. self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
  416. self.A_logs = nn.Parameter(torch.randn((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  417. self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner, self.dt_rank))) # 0.1 is added in 0430
  418. self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner))) # 0.1 is added in 0430
  419. elif initialize in ["v2"]:
  420. # simple init dt_projs, A_logs, Ds
  421. self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
  422. self.A_logs = nn.Parameter(torch.zeros((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  423. self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner, self.dt_rank)))
  424. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner)))
  425. def forward_corev2(
  426. self,
  427. x: torch.Tensor=None,
  428. # ==============================
  429. force_fp32=False, # True: input fp32
  430. # ==============================
  431. ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input
  432. no_einsum=False, # replace einsum with linear or conv1d to raise throughput
  433. # ==============================
  434. selective_scan_backend = None,
  435. # ==============================
  436. scan_mode = "cross2d",
  437. scan_force_torch = False,
  438. # ==============================
  439. **kwargs,
  440. ):
  441. assert selective_scan_backend in [None, "oflex", "mamba", "torch"]
  442. _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode # for debug
  443. assert isinstance(_scan_mode, int)
  444. delta_softplus = True
  445. out_norm = self.out_norm
  446. channel_first = self.channel_first
  447. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  448. B, D, H, W = x.shape
  449. N = self.d_state
  450. K, D, R = self.k_group, self.d_inner, self.dt_rank
  451. L = H * W
  452. def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
  453. return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend)
  454. if _scan_mode == -1:
  455. x_proj_bias = getattr(self, "x_proj_bias", None)
  456. def scan_rowcol(
  457. x: torch.Tensor,
  458. proj_weight: torch.Tensor,
  459. proj_bias: torch.Tensor,
  460. dt_weight: torch.Tensor,
  461. dt_bias: torch.Tensor, # (2*c)
  462. _As: torch.Tensor, # As = -torch.exp(A_logs.to(torch.float))[:2,] # (2*c, d_state)
  463. _Ds: torch.Tensor,
  464. width = True,
  465. ):
  466. # x: (B, D, H, W)
  467. # proj_weight: (2 * D, (R+N+N))
  468. XB, XD, XH, XW = x.shape
  469. if width:
  470. _B, _D, _L = XB * XH, XD, XW
  471. xs = x.permute(0, 2, 1, 3).contiguous()
  472. else:
  473. _B, _D, _L = XB * XW, XD, XH
  474. xs = x.permute(0, 3, 1, 2).contiguous()
  475. xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2) # (B, H, 2, D, W)
  476. if no_einsum:
  477. x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2)
  478. dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2)
  479. dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2)
  480. else:
  481. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight)
  482. if x_proj_bias is not None:
  483. x_dbl = x_dbl + x_proj_bias.view(1, 2, -1, 1)
  484. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  485. dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_weight)
  486. xs = xs.view(_B, -1, _L)
  487. dts = dts.contiguous().view(_B, -1, _L)
  488. As = _As.view(-1, N).to(torch.float)
  489. Bs = Bs.contiguous().view(_B, 2, N, _L)
  490. Cs = Cs.contiguous().view(_B, 2, N, _L)
  491. Ds = _Ds.view(-1)
  492. delta_bias = dt_bias.view(-1).to(torch.float)
  493. if force_fp32:
  494. xs = xs.to(torch.float)
  495. dts = dts.to(xs.dtype)
  496. Bs = Bs.to(xs.dtype)
  497. Cs = Cs.to(xs.dtype)
  498. ys: torch.Tensor = selective_scan(
  499. xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  500. ).view(_B, 2, -1, _L)
  501. return ys
  502. As = -self.A_logs.to(torch.float).exp().view(4, -1, N)
  503. x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
  504. y_row = scan_rowcol(
  505. x,
  506. proj_weight = self.x_proj_weight.view(4, -1, D)[:2].contiguous(),
  507. proj_bias = (x_proj_bias.view(4, -1)[:2].contiguous() if x_proj_bias is not None else None),
  508. dt_weight = self.dt_projs_weight.view(4, D, -1)[:2].contiguous(),
  509. dt_bias = (self.dt_projs_bias.view(4, -1)[:2].contiguous() if self.dt_projs_bias is not None else None),
  510. _As = As[:2].contiguous().view(-1, N),
  511. _Ds = self.Ds.view(4, -1)[:2].contiguous().view(-1),
  512. width=True,
  513. ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3) # (B,C,H,W)
  514. y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
  515. y_col = scan_rowcol(
  516. y_row,
  517. proj_weight = self.x_proj_weight.view(4, -1, D)[2:].contiguous().to(y_row.dtype),
  518. proj_bias = (x_proj_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if x_proj_bias is not None else None),
  519. dt_weight = self.dt_projs_weight.view(4, D, -1)[2:].contiguous().to(y_row.dtype),
  520. dt_bias = (self.dt_projs_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if self.dt_projs_bias is not None else None),
  521. _As = As[2:].contiguous().view(-1, N),
  522. _Ds = self.Ds.view(4, -1)[2:].contiguous().view(-1),
  523. width=False,
  524. ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1)
  525. y = y_col
  526. else:
  527. x_proj_bias = getattr(self, "x_proj_bias", None)
  528. xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  529. if no_einsum:
  530. x_dbl = F.conv1d(xs.view(B, -1, L), self.x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
  531. dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
  532. if hasattr(self, "dt_projs_weight"):
  533. dts = F.conv1d(dts.contiguous().view(B, -1, L), self.dt_projs_weight.view(K * D, -1, 1), groups=K)
  534. else:
  535. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
  536. if x_proj_bias is not None:
  537. x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
  538. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  539. if hasattr(self, "dt_projs_weight"):
  540. dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
  541. xs = xs.view(B, -1, L)
  542. dts = dts.contiguous().view(B, -1, L)
  543. As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
  544. Ds = self.Ds.to(torch.float) # (K * c)
  545. Bs = Bs.contiguous().view(B, K, N, L)
  546. Cs = Cs.contiguous().view(B, K, N, L)
  547. delta_bias = self.dt_projs_bias.view(-1).to(torch.float)
  548. if force_fp32:
  549. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  550. ys: torch.Tensor = selective_scan(
  551. xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  552. ).view(B, K, -1, H, W)
  553. y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  554. if getattr(self, "__DEBUG__", False):
  555. setattr(self, "__data__", dict(
  556. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
  557. us=xs, dts=dts, delta_bias=delta_bias,
  558. ys=ys, y=y, H=H, W=W,
  559. ))
  560. y = y.view(B, -1, H, W)
  561. if not channel_first:
  562. y = y.view(B, -1, H * W).transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) # (B, L, C)
  563. y = out_norm(y)
  564. return y.to(x.dtype)
  565. def forwardv2(self, x: torch.Tensor, **kwargs):
  566. x = self.in_proj(x)
  567. if not self.disable_z:
  568. x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
  569. if not self.disable_z_act:
  570. z = self.act(z)
  571. if not self.channel_first:
  572. x = x.permute(0, 3, 1, 2).contiguous()
  573. if self.with_dconv:
  574. x = self.conv2d(x) # (b, d, h, w)
  575. x = self.act(x)
  576. y = self.forward_core(x)
  577. y = self.out_act(y)
  578. if not self.disable_z:
  579. y = y * z
  580. out = self.dropout(self.out_proj(y))
  581. return out
  582. @staticmethod
  583. def get_outnorm(forward_type="", d_inner=192, channel_first=True):
  584. def checkpostfix(tag, value):
  585. ret = value[-len(tag):] == tag
  586. if ret:
  587. value = value[:-len(tag)]
  588. return ret, value
  589. LayerNorm = LayerNorm2d if channel_first else nn.LayerNorm
  590. out_norm_none, forward_type = checkpostfix("_onnone", forward_type)
  591. out_norm_dwconv3, forward_type = checkpostfix("_ondwconv3", forward_type)
  592. out_norm_cnorm, forward_type = checkpostfix("_oncnorm", forward_type)
  593. out_norm_softmax, forward_type = checkpostfix("_onsoftmax", forward_type)
  594. out_norm_sigmoid, forward_type = checkpostfix("_onsigmoid", forward_type)
  595. out_norm = nn.Identity()
  596. if out_norm_none:
  597. out_norm = nn.Identity()
  598. elif out_norm_cnorm:
  599. out_norm = nn.Sequential(
  600. LayerNorm(d_inner),
  601. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  602. nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
  603. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  604. )
  605. elif out_norm_dwconv3:
  606. out_norm = nn.Sequential(
  607. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  608. nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
  609. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  610. )
  611. elif out_norm_softmax:
  612. out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1))
  613. elif out_norm_sigmoid:
  614. out_norm = nn.Sigmoid()
  615. else:
  616. out_norm = LayerNorm(d_inner)
  617. return out_norm, forward_type
  618. @staticmethod
  619. def checkpostfix(tag, value):
  620. ret = value[-len(tag):] == tag
  621. if ret:
  622. value = value[:-len(tag)]
  623. return ret, value
  624. # support: xv1a,xv2a,xv3a;
  625. # postfix: _cpos;_ocov;_ocov2;_ca,_ca1;_act;_mul;_onsigmoid,_onsoftmax,_ondwconv3,_onnone;
  626. class SS2Dv3:
  627. def __initxv__(
  628. self,
  629. # basic dims ===========
  630. d_model=96,
  631. d_state=16,
  632. ssm_ratio=2.0,
  633. dt_rank="auto",
  634. # dwconv ===============
  635. d_conv=3, # < 2 means no conv
  636. conv_bias=True,
  637. # ======================
  638. dropout=0.0,
  639. bias=False,
  640. # dt init ==============
  641. dt_min=0.001,
  642. dt_max=0.1,
  643. dt_init="random",
  644. dt_scale=1.0,
  645. dt_init_floor=1e-4,
  646. initialize="v0",
  647. # ======================
  648. forward_type="v2",
  649. channel_first=False,
  650. # ======================
  651. **kwargs,
  652. ):
  653. super().__init__()
  654. d_inner = int(ssm_ratio * d_model)
  655. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  656. self.channel_first = channel_first
  657. self.d_state = d_state
  658. self.dt_rank = dt_rank
  659. self.d_inner = d_inner
  660. k_group = 4
  661. self.with_dconv = d_conv > 1
  662. Linear = Linear2d if channel_first else nn.Linear
  663. self.forward = self.forwardxv
  664. # tags for forward_type ==============================
  665. checkpostfix = SS2Dv2.checkpostfix
  666. self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, channel_first)
  667. self.omul, forward_type = checkpostfix("_mul", forward_type)
  668. self.oact, forward_type = checkpostfix("_act", forward_type)
  669. self.f_omul = nn.Identity() if self.omul else None
  670. self.out_act = nn.GELU() if self.oact else nn.Identity()
  671. mode = forward_type[:4]
  672. assert mode in ["xv1a", "xv2a", "xv3a"]
  673. self.forward = partial(self.forwardxv, mode=mode)
  674. self.dts_dim = dict(xv1a=self.dt_rank, xv2a=self.d_inner, xv3a=4 * self.dt_rank)[mode]
  675. d_inner_all = d_inner + self.dts_dim + 8 * d_state
  676. self.in_proj = Linear(d_model, d_inner_all, bias=bias)
  677. # conv =======================================
  678. self.cpos = False
  679. self.iconv = False
  680. self.oconv = False
  681. self.oconv2 = False
  682. if self.with_dconv:
  683. cact, forward_type = checkpostfix("_ca", forward_type)
  684. cact1, forward_type = checkpostfix("_ca1", forward_type)
  685. self.cact = nn.SiLU() if cact else nn.Identity()
  686. self.cact = nn.GELU() if cact1 else self.cact
  687. self.oconv2, forward_type = checkpostfix("_ocov2", forward_type)
  688. self.oconv, forward_type = checkpostfix("_ocov", forward_type)
  689. self.cpos, forward_type = checkpostfix("_cpos", forward_type)
  690. self.iconv = (not self.oconv) and (not self.oconv2)
  691. if self.iconv:
  692. self.conv2d = nn.Conv2d(
  693. in_channels=d_model,
  694. out_channels=d_model,
  695. groups=d_model,
  696. bias=conv_bias,
  697. kernel_size=d_conv,
  698. padding=(d_conv - 1) // 2,
  699. )
  700. if self.oconv:
  701. self.oconv2d = nn.Conv2d(
  702. in_channels=d_inner,
  703. out_channels=d_inner,
  704. groups=d_inner,
  705. bias=conv_bias,
  706. kernel_size=d_conv,
  707. padding=(d_conv - 1) // 2,
  708. )
  709. if self.oconv2:
  710. self.conv2d = nn.Conv2d(
  711. in_channels=d_inner_all,
  712. out_channels=d_inner_all,
  713. groups=d_inner_all,
  714. bias=conv_bias,
  715. kernel_size=d_conv,
  716. padding=(d_conv - 1) // 2,
  717. )
  718. # out proj =======================================
  719. self.out_proj = Linear(d_inner, d_model, bias=bias)
  720. self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
  721. if initialize in ["v0"]:
  722. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  723. d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
  724. )
  725. elif initialize in ["v1"]:
  726. # simple init dt_projs, A_logs, Ds
  727. self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
  728. self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  729. self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
  730. self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
  731. elif initialize in ["v2"]:
  732. # simple init dt_projs, A_logs, Ds
  733. self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
  734. self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  735. self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank)))
  736. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner)))
  737. if forward_type.startswith("xv2"):
  738. del self.dt_projs_weight
  739. self.dt_projs_weight = None
  740. def forwardxv(self, x: torch.Tensor, **kwargs):
  741. B, (H, W) = x.shape[0], (x.shape[2:4] if self.channel_first else x.shape[1:3])
  742. L = H * W
  743. force_fp32 = False
  744. delta_softplus = True
  745. out_norm = self.out_norm
  746. to_dtype = True
  747. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  748. def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):
  749. return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex=True, backend=None)
  750. if self.iconv:
  751. x = self.cact(self.conv2d(x)) # (b, d, h, w)
  752. elif self.cpos:
  753. x = x + self.conv2d(x) # (b, d, h, w)
  754. x = self.in_proj(x)
  755. if self.oconv2:
  756. x = self.conv2d(x) # (b, d, h, w)
  757. us, dts, Bs, Cs = x.split([self.d_inner, self.dts_dim, 4 * self.d_state, 4 * self.d_state], dim=(1 if self.channel_first else -1))
  758. _us = us
  759. # Bs, Cs = Bs.view(B, H, W, 4, -1), Cs.view(B, H, W, 4, -1)
  760. # Bs, Cs = Bs.view(B, 4, -1, H, W), Cs.view(B, 4, -1, H, W)
  761. us = cross_scan_fn(us.contiguous(), in_channel_first=self.channel_first, out_channel_first=True).view(B, -1, L)
  762. Bs = cross_scan_fn(Bs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
  763. Cs = cross_scan_fn(Cs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
  764. dts = cross_scan_fn(dts.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=(self.dts_dim == 4 * self.dt_rank)).view(B, L, -1)
  765. if self.dts_dim == self.dt_rank:
  766. dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)
  767. elif self.dts_dim == 4 * self.dt_rank:
  768. dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)
  769. As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
  770. Ds = self.Ds.to(torch.float) # (K * c)
  771. delta_bias = self.dt_projs_bias.view(-1).to(torch.float) # (K * c)
  772. if force_fp32:
  773. us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs)
  774. ys: torch.Tensor = selective_scan(
  775. us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  776. ).view(B, 4, -1, H, W)
  777. y: torch.Tensor = cross_merge_fn(ys.contiguous(), in_channel_first=self.channel_first, out_channel_first=True)
  778. y = y.view(B, -1, H, W) if self.channel_first else y.view(B, H, W, -1)
  779. y = out_norm(y)
  780. if getattr(self, "__DEBUG__", False):
  781. setattr(self, "__data__", dict(
  782. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
  783. us=us, dts=dts, delta_bias=delta_bias,
  784. ys=ys, y=y,
  785. ))
  786. y = (y.to(x.dtype) if to_dtype else y)
  787. y = self.out_act(y)
  788. if self.omul:
  789. y = y * _us
  790. if self.oconv:
  791. y = y + self.cact(self.oconv2d(_us))
  792. out = self.dropout(self.out_proj(y))
  793. return out
  794. # mamba2 support ================================
  795. class SS2Dm0:
  796. def __initm0__(
  797. self,
  798. # basic dims ===========
  799. d_model=96,
  800. d_state=16, # now with mamba2, dstate should be bigger...
  801. ssm_ratio=2.0,
  802. dt_rank="auto",
  803. act_layer=nn.GELU,
  804. # dwconv ===============
  805. d_conv=3, # < 2 means no conv
  806. conv_bias=True,
  807. # ======================
  808. dropout=0.0,
  809. bias=False,
  810. # dt init ==============
  811. dt_min=0.001,
  812. dt_max=0.1,
  813. dt_init="random",
  814. dt_scale=1.0,
  815. dt_init_floor=1e-4,
  816. initialize="v2",
  817. # ======================
  818. forward_type="m0",
  819. # ======================
  820. with_initial_state=False,
  821. # ======================
  822. **kwargs,
  823. ):
  824. factory_kwargs = {"device": None, "dtype": None}
  825. super().__init__()
  826. d_inner = int(ssm_ratio * d_model)
  827. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  828. assert d_inner % dt_rank == 0
  829. self.with_dconv = d_conv > 1
  830. Linear = nn.Linear
  831. self.forward = self.forwardm0
  832. # tags for forward_type ==============================
  833. checkpostfix = SS2Dv2.checkpostfix
  834. self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
  835. self.oact, forward_type = checkpostfix("_oact", forward_type)
  836. self.disable_z, forward_type = checkpostfix("_noz", forward_type)
  837. self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
  838. self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False)
  839. # forward_type debug =======================================
  840. FORWARD_TYPES = dict(
  841. m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state),
  842. )
  843. self.forward_core = FORWARD_TYPES.get(forward_type, None)
  844. k_group = 4
  845. # in proj =======================================
  846. d_proj = d_inner if self.disable_z else (d_inner * 2)
  847. self.in_proj = Linear(d_model, d_proj, bias=bias)
  848. self.act: nn.Module = act_layer()
  849. # conv =======================================
  850. if self.with_dconv:
  851. self.conv2d = nn.Sequential(
  852. Permute(0, 3, 1, 2),
  853. nn.Conv2d(
  854. in_channels=d_inner,
  855. out_channels=d_inner,
  856. groups=d_inner,
  857. bias=conv_bias,
  858. kernel_size=d_conv,
  859. padding=(d_conv - 1) // 2,
  860. **factory_kwargs,
  861. ),
  862. Permute(0, 2, 3, 1),
  863. )
  864. # x proj ============================
  865. self.x_proj = [
  866. nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
  867. for _ in range(k_group)
  868. ]
  869. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  870. del self.x_proj
  871. # out proj =======================================
  872. self.out_act = nn.GELU() if self.oact else nn.Identity()
  873. self.out_proj = Linear(d_inner, d_model, bias=bias)
  874. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  875. if initialize in ["v1"]:
  876. # simple init dt_projs, A_logs, Ds
  877. self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
  878. self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  879. self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank))) # 0.1 is added in 0430
  880. elif initialize in ["v2"]:
  881. # simple init dt_projs, A_logs, Ds
  882. self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
  883. self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  884. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank)))
  885. # init state ============================
  886. self.initial_state = None
  887. if with_initial_state:
  888. self.initial_state = nn.Parameter(torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), requires_grad=False)
  889. def forward_corem0(
  890. self,
  891. x: torch.Tensor=None,
  892. # ==============================
  893. force_fp32=False, # True: input fp32
  894. chunk_size = 64,
  895. dstate = 64,
  896. # ==============================
  897. selective_scan_backend = None,
  898. scan_mode = "cross2d",
  899. scan_force_torch = False,
  900. # ==============================
  901. **kwargs,
  902. ):
  903. assert scan_mode in ["unidi", "bidi", "cross2d"]
  904. assert selective_scan_backend in [None, "triton", "torch"]
  905. x_proj_bias = getattr(self, "x_proj_bias", None)
  906. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  907. N = dstate
  908. B, H, W, RD = x.shape
  909. K, R = self.A_logs.shape
  910. K, R, D = self.Ds.shape
  911. assert RD == R * D
  912. L = H * W
  913. KR = K * R
  914. _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode]
  915. initial_state = None
  916. if self.initial_state is not None:
  917. assert self.initial_state.shape[-1] == dstate
  918. initial_state = self.initial_state.detach().repeat(B, 1, 1, 1)
  919. xs = cross_scan_fn(x.view(B, H, W, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch) # (B, H, W, 4, D)
  920. x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight)
  921. if x_proj_bias is not None:
  922. x_dbl = x_dbl + x_proj_bias.view(1, -1, K, 1)
  923. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3)
  924. xs = xs.contiguous().view(B, L, KR, D)
  925. dts = dts.contiguous().view(B, L, KR)
  926. Bs = Bs.contiguous().view(B, L, K, N)
  927. Cs = Cs.contiguous().view(B, L, K, N)
  928. if force_fp32:
  929. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  930. As = -self.A_logs.to(torch.float).exp().view(KR)
  931. Ds = self.Ds.to(torch.float).view(KR, D)
  932. dt_bias = self.dt_projs_bias.view(KR)
  933. if force_fp32:
  934. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  935. ys, final_state = selective_scan_chunk_fn(
  936. xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias,
  937. initial_states=initial_state, dt_softplus=True, return_final_states=True,
  938. backend=selective_scan_backend,
  939. )
  940. y: torch.Tensor = cross_merge_fn(ys.view(B, H, W, K, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)
  941. if getattr(self, "__DEBUG__", False):
  942. setattr(self, "__data__", dict(
  943. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds,
  944. us=xs, dts=dts, delta_bias=self.dt_projs_bias,
  945. initial_state=self.initial_state, final_satte=final_state,
  946. ys=ys, y=y, H=H, W=W,
  947. ))
  948. if self.initial_state is not None:
  949. self.initial_state = nn.Parameter(final_state.detach().sum(0, keepdim=True), requires_grad=False)
  950. y = self.out_norm(y.view(B, H, W, -1))
  951. return y.to(x.dtype)
  952. def forwardm0(self, x: torch.Tensor, **kwargs):
  953. x = self.in_proj(x)
  954. if not self.disable_z:
  955. x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
  956. if not self.disable_z_act:
  957. z = self.act(z)
  958. if self.with_dconv:
  959. x = self.conv2d(x) # (b, d, h, w)
  960. x = self.act(x)
  961. y = self.forward_core(x)
  962. y = self.out_act(y)
  963. if not self.disable_z:
  964. y = y * z
  965. out = self.dropout(self.out_proj(y))
  966. return out
  967. class SS2D(nn.Module, SS2Dv0, SS2Dv2, SS2Dv3, SS2Dm0):
  968. def __init__(
  969. self,
  970. # basic dims ===========
  971. d_model=96,
  972. d_state=16,
  973. ssm_ratio=2.0,
  974. dt_rank="auto",
  975. act_layer=nn.SiLU,
  976. # dwconv ===============
  977. d_conv=3, # < 2 means no conv
  978. conv_bias=True,
  979. # ======================
  980. dropout=0.0,
  981. bias=False,
  982. # dt init ==============
  983. dt_min=0.001,
  984. dt_max=0.1,
  985. dt_init="random",
  986. dt_scale=1.0,
  987. dt_init_floor=1e-4,
  988. initialize="v0",
  989. # ======================
  990. forward_type="v2",
  991. channel_first=False,
  992. # ======================
  993. **kwargs,
  994. ):
  995. nn.Module.__init__(self)
  996. kwargs.update(
  997. d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank,
  998. act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias,
  999. dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor,
  1000. initialize=initialize, forward_type=forward_type, channel_first=channel_first,
  1001. )
  1002. if forward_type in ["v0", "v0seq"]:
  1003. self.__initv0__(seq=("seq" in forward_type), **kwargs)
  1004. elif forward_type.startswith("xv"):
  1005. self.__initxv__(**kwargs)
  1006. elif forward_type.startswith("m"):
  1007. self.__initm0__(**kwargs)
  1008. else:
  1009. self.__initv2__(**kwargs)
  1010. # =====================================================
  1011. class VSSBlock(nn.Module):
  1012. def __init__(
  1013. self,
  1014. hidden_dim: int = 0,
  1015. drop_path: float = 0,
  1016. norm_layer: nn.Module = nn.LayerNorm,
  1017. channel_first=False,
  1018. # =============================
  1019. ssm_d_state: int = 16,
  1020. ssm_ratio=2.0,
  1021. ssm_dt_rank: Any = "auto",
  1022. ssm_act_layer=nn.SiLU,
  1023. ssm_conv: int = 3,
  1024. ssm_conv_bias=True,
  1025. ssm_drop_rate: float = 0,
  1026. ssm_init="v0",
  1027. forward_type="v2",
  1028. # =============================
  1029. mlp_ratio=4.0,
  1030. mlp_act_layer=nn.GELU,
  1031. mlp_drop_rate: float = 0.0,
  1032. gmlp=False,
  1033. # =============================
  1034. use_checkpoint: bool = False,
  1035. post_norm: bool = False,
  1036. # =============================
  1037. _SS2D: type = SS2D,
  1038. **kwargs,
  1039. ):
  1040. super().__init__()
  1041. self.ssm_branch = ssm_ratio > 0
  1042. self.mlp_branch = mlp_ratio > 0
  1043. self.use_checkpoint = use_checkpoint
  1044. self.post_norm = post_norm
  1045. if self.ssm_branch:
  1046. self.norm = norm_layer(hidden_dim)
  1047. self.op = _SS2D(
  1048. d_model=hidden_dim,
  1049. d_state=ssm_d_state,
  1050. ssm_ratio=ssm_ratio,
  1051. dt_rank=ssm_dt_rank,
  1052. act_layer=ssm_act_layer,
  1053. # ==========================
  1054. d_conv=ssm_conv,
  1055. conv_bias=ssm_conv_bias,
  1056. # ==========================
  1057. dropout=ssm_drop_rate,
  1058. # bias=False,
  1059. # ==========================
  1060. # dt_min=0.001,
  1061. # dt_max=0.1,
  1062. # dt_init="random",
  1063. # dt_scale="random",
  1064. # dt_init_floor=1e-4,
  1065. initialize=ssm_init,
  1066. # ==========================
  1067. forward_type=forward_type,
  1068. channel_first=channel_first,
  1069. )
  1070. self.drop_path = DropPath(drop_path)
  1071. if self.mlp_branch:
  1072. _MLP = Mlp if not gmlp else gMlp
  1073. self.norm2 = norm_layer(hidden_dim)
  1074. mlp_hidden_dim = int(hidden_dim * mlp_ratio)
  1075. self.mlp = _MLP(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=channel_first)
  1076. def _forward(self, input: torch.Tensor):
  1077. x = input
  1078. if self.ssm_branch:
  1079. if self.post_norm:
  1080. x = x + self.drop_path(self.norm(self.op(x)))
  1081. else:
  1082. x = x + self.drop_path(self.op(self.norm(x)))
  1083. if self.mlp_branch:
  1084. if self.post_norm:
  1085. x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
  1086. else:
  1087. x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
  1088. return x
  1089. def forward(self, input: torch.Tensor):
  1090. if self.use_checkpoint:
  1091. return checkpoint.checkpoint(self._forward, input)
  1092. else:
  1093. return self._forward(input)
  1094. class VSSM(nn.Module):
  1095. def __init__(
  1096. self,
  1097. patch_size=4,
  1098. in_chans=3,
  1099. num_classes=1000,
  1100. depths=[2, 2, 9, 2],
  1101. dims=[96, 192, 384, 768],
  1102. # =========================
  1103. ssm_d_state=16,
  1104. ssm_ratio=2.0,
  1105. ssm_dt_rank="auto",
  1106. ssm_act_layer="silu",
  1107. ssm_conv=3,
  1108. ssm_conv_bias=True,
  1109. ssm_drop_rate=0.0,
  1110. ssm_init="v0",
  1111. forward_type="v2",
  1112. # =========================
  1113. mlp_ratio=4.0,
  1114. mlp_act_layer="gelu",
  1115. mlp_drop_rate=0.0,
  1116. gmlp=False,
  1117. # =========================
  1118. drop_path_rate=0.1,
  1119. patch_norm=True,
  1120. norm_layer="LN", # "BN", "LN2D"
  1121. downsample_version: str = "v2", # "v1", "v2", "v3"
  1122. patchembed_version: str = "v1", # "v1", "v2"
  1123. use_checkpoint=False,
  1124. # =========================
  1125. posembed=False,
  1126. imgsize=224,
  1127. _SS2D=SS2D,
  1128. # =========================
  1129. **kwargs,
  1130. ):
  1131. super().__init__()
  1132. self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
  1133. self.num_classes = num_classes
  1134. self.num_layers = len(depths)
  1135. if isinstance(dims, int):
  1136. dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
  1137. self.num_features = dims[-1]
  1138. self.dims = dims
  1139. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  1140. _NORMLAYERS = dict(
  1141. ln=nn.LayerNorm,
  1142. ln2d=LayerNorm2d,
  1143. bn=nn.BatchNorm2d,
  1144. )
  1145. _ACTLAYERS = dict(
  1146. silu=nn.SiLU,
  1147. gelu=nn.GELU,
  1148. relu=nn.ReLU,
  1149. sigmoid=nn.Sigmoid,
  1150. )
  1151. norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)
  1152. ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
  1153. mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)
  1154. self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None
  1155. _make_patch_embed = dict(
  1156. v1=self._make_patch_embed,
  1157. v2=self._make_patch_embed_v2,
  1158. ).get(patchembed_version, None)
  1159. self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first)
  1160. _make_downsample = dict(
  1161. v1=PatchMerging2D,
  1162. v2=self._make_downsample,
  1163. v3=self._make_downsample_v3,
  1164. none=(lambda *_, **_k: None),
  1165. ).get(downsample_version, None)
  1166. self.layers = nn.ModuleList()
  1167. for i_layer in range(self.num_layers):
  1168. downsample = _make_downsample(
  1169. self.dims[i_layer],
  1170. self.dims[i_layer + 1],
  1171. norm_layer=norm_layer,
  1172. channel_first=self.channel_first,
  1173. ) if (i_layer < self.num_layers - 1) else nn.Identity()
  1174. self.layers.append(self._make_layer(
  1175. dim = self.dims[i_layer],
  1176. drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  1177. use_checkpoint=use_checkpoint,
  1178. norm_layer=norm_layer,
  1179. downsample=downsample,
  1180. channel_first=self.channel_first,
  1181. # =================
  1182. ssm_d_state=ssm_d_state,
  1183. ssm_ratio=ssm_ratio,
  1184. ssm_dt_rank=ssm_dt_rank,
  1185. ssm_act_layer=ssm_act_layer,
  1186. ssm_conv=ssm_conv,
  1187. ssm_conv_bias=ssm_conv_bias,
  1188. ssm_drop_rate=ssm_drop_rate,
  1189. ssm_init=ssm_init,
  1190. forward_type=forward_type,
  1191. # =================
  1192. mlp_ratio=mlp_ratio,
  1193. mlp_act_layer=mlp_act_layer,
  1194. mlp_drop_rate=mlp_drop_rate,
  1195. gmlp=gmlp,
  1196. # =================
  1197. _SS2D=_SS2D,
  1198. ))
  1199. self.classifier = nn.Sequential(OrderedDict(
  1200. norm=norm_layer(self.num_features), # B,H,W,C
  1201. permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
  1202. avgpool=nn.AdaptiveAvgPool2d(1),
  1203. flatten=nn.Flatten(1),
  1204. head=nn.Linear(self.num_features, num_classes),
  1205. ))
  1206. self.apply(self._init_weights)
  1207. @staticmethod
  1208. def _pos_embed(embed_dims, patch_size, img_size):
  1209. patch_height, patch_width = (img_size // patch_size, img_size // patch_size)
  1210. pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width))
  1211. trunc_normal_(pos_embed, std=0.02)
  1212. return pos_embed
  1213. def _init_weights(self, m: nn.Module):
  1214. if isinstance(m, nn.Linear):
  1215. trunc_normal_(m.weight, std=.02)
  1216. if isinstance(m, nn.Linear) and m.bias is not None:
  1217. nn.init.constant_(m.bias, 0)
  1218. elif isinstance(m, nn.LayerNorm):
  1219. nn.init.constant_(m.bias, 0)
  1220. nn.init.constant_(m.weight, 1.0)
  1221. # used in building optimizer
  1222. @torch.jit.ignore
  1223. def no_weight_decay(self):
  1224. return {"pos_embed"}
  1225. # used in building optimizer
  1226. @torch.jit.ignore
  1227. def no_weight_decay_keywords(self):
  1228. return {}
  1229. @staticmethod
  1230. def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
  1231. # if channel first, then Norm and Output are both channel_first
  1232. return nn.Sequential(
  1233. nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
  1234. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1235. (norm_layer(embed_dim) if patch_norm else nn.Identity()),
  1236. )
  1237. @staticmethod
  1238. def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
  1239. # if channel first, then Norm and Output are both channel_first
  1240. stride = patch_size // 2
  1241. kernel_size = stride + 1
  1242. padding = 1
  1243. return nn.Sequential(
  1244. nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
  1245. (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)),
  1246. (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
  1247. (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)),
  1248. nn.GELU(),
  1249. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
  1250. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1251. (norm_layer(embed_dim) if patch_norm else nn.Identity()),
  1252. )
  1253. @staticmethod
  1254. def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
  1255. # if channel first, then Norm and Output are both channel_first
  1256. return nn.Sequential(
  1257. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1258. nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
  1259. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1260. norm_layer(out_dim),
  1261. )
  1262. @staticmethod
  1263. def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
  1264. # if channel first, then Norm and Output are both channel_first
  1265. return nn.Sequential(
  1266. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1267. nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
  1268. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1269. norm_layer(out_dim),
  1270. )
  1271. @staticmethod
  1272. def _make_layer(
  1273. dim=96,
  1274. drop_path=[0.1, 0.1],
  1275. use_checkpoint=False,
  1276. norm_layer=nn.LayerNorm,
  1277. downsample=nn.Identity(),
  1278. channel_first=False,
  1279. # ===========================
  1280. ssm_d_state=16,
  1281. ssm_ratio=2.0,
  1282. ssm_dt_rank="auto",
  1283. ssm_act_layer=nn.SiLU,
  1284. ssm_conv=3,
  1285. ssm_conv_bias=True,
  1286. ssm_drop_rate=0.0,
  1287. ssm_init="v0",
  1288. forward_type="v2",
  1289. # ===========================
  1290. mlp_ratio=4.0,
  1291. mlp_act_layer=nn.GELU,
  1292. mlp_drop_rate=0.0,
  1293. gmlp=False,
  1294. # ===========================
  1295. _SS2D=SS2D,
  1296. **kwargs,
  1297. ):
  1298. # if channel first, then Norm and Output are both channel_first
  1299. depth = len(drop_path)
  1300. blocks = []
  1301. for d in range(depth):
  1302. blocks.append(VSSBlock(
  1303. hidden_dim=dim,
  1304. drop_path=drop_path[d],
  1305. norm_layer=norm_layer,
  1306. channel_first=channel_first,
  1307. ssm_d_state=ssm_d_state,
  1308. ssm_ratio=ssm_ratio,
  1309. ssm_dt_rank=ssm_dt_rank,
  1310. ssm_act_layer=ssm_act_layer,
  1311. ssm_conv=ssm_conv,
  1312. ssm_conv_bias=ssm_conv_bias,
  1313. ssm_drop_rate=ssm_drop_rate,
  1314. ssm_init=ssm_init,
  1315. forward_type=forward_type,
  1316. mlp_ratio=mlp_ratio,
  1317. mlp_act_layer=mlp_act_layer,
  1318. mlp_drop_rate=mlp_drop_rate,
  1319. gmlp=gmlp,
  1320. use_checkpoint=use_checkpoint,
  1321. _SS2D=_SS2D,
  1322. ))
  1323. return nn.Sequential(OrderedDict(
  1324. blocks=nn.Sequential(*blocks,),
  1325. downsample=downsample,
  1326. ))
  def forward(self, x: torch.Tensor):
      """Run the classifier end to end and return the head's output.

      Pipeline: ``patch_embed`` -> optional ``pos_embed`` addition -> each stage in
      ``self.layers`` in order -> ``classifier``. When the model is not
      channel-first, ``pos_embed`` is permuted with (0, 2, 3, 1) before the
      addition so that it lines up with channels-last activations.
      """
      feats = self.patch_embed(x)
      pos = self.pos_embed
      if pos is not None:
          if not self.channel_first:
              pos = pos.permute(0, 2, 3, 1)
          feats = feats + pos
      for stage in self.layers:
          feats = stage(feats)
      return self.classifier(feats)
  def flops(self, shape=(3, 224, 224), verbose=True):
      """Count forward FLOPs for a single ``(1, *shape)`` random input using fvcore.

      A deep copy of the model is moved to CUDA and put in eval mode, so ``self``
      is not modified. The elementwise ops mapped to ``None`` below are ignored
      by fvcore (like relu/permute, which fvcore already ignores). The custom
      ``SelectiveScanCuda`` op is counted by ``selective_scan_flop_jit``.

      Args:
          shape: per-sample input shape, without the batch dimension.
          verbose: forwarded to ``selective_scan_flop_jit``.

      Returns:
          Total FLOPs as a float. fvcore reports GFLOPs per operator; the sum is
          scaled by 1e9.
      """
      supported_ops = {
          "aten::silu": None,  # as relu is in _IGNORED_OPS
          "aten::neg": None,  # as relu is in _IGNORED_OPS
          "aten::exp": None,  # as relu is in _IGNORED_OPS
          "aten::flip": None,  # as permute is in _IGNORED_OPS
          "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose),
      }
      model = copy.deepcopy(self)
      model.cuda().eval()
      inp = torch.randn((1, *shape), device=next(model.parameters()).device)
      # The second return value (unsupported-op counts) is not used.
      gflops_per_op, _ = flop_count(model=model, inputs=(inp,), supported_ops=supported_ops)
      del model, inp
      return sum(gflops_per_op.values()) * 1e9
  # used to load ckpt from previous training code
  def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
      """Rewrite legacy checkpoint keys in place, then defer to ``nn.Module``.

      Overrides ``nn.Module._load_from_state_dict``. Before the standard loader
      runs, keys under ``prefix`` are renamed by plain prefix match:

      * ``patch_embed.proj`` -> ``patch_embed.0``
      * ``patch_embed.norm`` -> ``patch_embed.2``
      * ``layers.{i}.blocks.{j}.ln_1`` -> ``layers.{i}.blocks.{j}.norm``
      * ``layers.{i}.blocks.{j}.self_attention`` -> ``layers.{i}.blocks.{j}.op``
      * ``norm`` -> ``classifier.norm``
      * ``head`` -> ``classifier.head``

      If the checkpoint has ``pos_embed`` and this model has a positional
      embedding, the checkpoint tensor is bicubically resized to this model's
      spatial size (dims 2:4).
      """
      import re

      def change_name(src, dst):
          # Rename every key that starts with prefix + src, keeping the remainder of the key.
          key = prefix + src
          for k in [k for k in state_dict if k.startswith(key)]:
              state_dict[prefix + dst + k[len(key):]] = state_dict.pop(k)

      pos_key = prefix + "pos_embed"
      # Only resize when this model actually has a positional embedding. Otherwise
      # leave the key untouched so the base loader reports it as unexpected, instead
      # of crashing on ``None.shape``.
      if pos_key in state_dict and getattr(self, "pos_embed", None) is not None:
          src_emb: torch.Tensor = state_dict[pos_key]
          state_dict[pos_key] = F.interpolate(
              src_emb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic"
          ).to(src_emb.device)

      change_name("patch_embed.proj", "patch_embed.0")
      change_name("patch_embed.norm", "patch_embed.2")

      # A single pass over the keys replaces the former 100 x 100 loop, which made
      # 20,000 full passes over the state dict and ignored indices >= 100.
      block_key = re.compile(re.escape(prefix) + r"(layers\.[0-9]+\.blocks\.[0-9]+\.)(ln_1|self_attention)")
      renamed = {"ln_1": "norm", "self_attention": "op"}
      for k in list(state_dict):
          m = block_key.match(k)
          if m is not None:
              state_dict[prefix + m.group(1) + renamed[m.group(2)] + k[m.end():]] = state_dict.pop(k)

      change_name("norm", "classifier.norm")
      change_name("head", "classifier.head")
      return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  # compatible with openmmlab
  class Backbone_VSSM(VSSM):
      """VSSM trunk that returns normalized per-stage feature maps (OpenMMLab style).

      The classifier is deleted. For each stage index ``i`` in ``out_indices`` an
      ``outnorm{i}`` normalization layer is registered, and ``forward`` returns
      those normalized stage outputs (taken before the stage's downsample) as
      NCHW tensors.

      Args:
          out_indices: stage indices whose outputs are returned.
          pretrained: optional checkpoint path passed to ``load_pretrained``.
          norm_layer: ``"ln"``, ``"ln2d"`` or ``"bn"`` (case-insensitive).
              ``"ln2d"`` and ``"bn"`` mark the model as channel-first.
          **kwargs: forwarded to ``VSSM`` (together with ``norm_layer``).
      """

      def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
          kwargs.update(norm_layer=norm_layer)
          super().__init__(**kwargs)
          norm_key = norm_layer.lower()
          self.channel_first = norm_key in ["bn", "ln2d"]
          _NORMLAYERS = dict(
              ln=nn.LayerNorm,
              ln2d=LayerNorm2d,
              bn=nn.BatchNorm2d,
          )
          norm_cls = _NORMLAYERS.get(norm_key, None)
          if norm_cls is None and len(out_indices) > 0:
              # Previously this surfaced later as an opaque "'NoneType' object is not callable".
              raise ValueError(
                  f"Unsupported norm_layer {norm_layer!r}; expected one of {sorted(_NORMLAYERS)}"
              )
          self.out_indices = out_indices
          for i in out_indices:
              self.add_module(f'outnorm{i}', norm_cls(self.dims[i]))
          del self.classifier
          self.load_pretrained(pretrained)

      def load_pretrained(self, ckpt=None, key="model"):
          """Best-effort, non-strict load of ``torch.load(ckpt)[key]`` into this model.

          Does nothing when ``ckpt`` is None. Any failure is printed, not raised.
          NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
          """
          if ckpt is None:
              return
          try:
              # The context manager closes the file; the previous bare open() handle leaked.
              with open(ckpt, "rb") as f:
                  _ckpt = torch.load(f, map_location=torch.device("cpu"))
              print(f"Successfully load ckpt {ckpt}")
              incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
              print(incompatibleKeys)
          except Exception as e:
              print(f"Failed loading checkpoint from {ckpt}: {e}")

      def forward(self, x):
          """Return the normalized NCHW outputs of the stages in ``out_indices``.

          If ``out_indices`` is empty, the output of the last stage's downsample
          is returned instead.
          NOTE(review): unlike ``VSSM.forward`` this never adds ``self.pos_embed``;
          presumably intentional (e.g. variable input sizes in detection) — confirm.
          """
          def layer_forward(l, x):
              x = l.blocks(x)
              y = l.downsample(x)
              return x, y

          x = self.patch_embed(x)
          outs = []
          for i, layer in enumerate(self.layers):
              o, x = layer_forward(layer, x)  # (B, H, W, C), or (B, C, H, W) when channel_first
              if i in self.out_indices:
                  norm_layer = getattr(self, f'outnorm{i}')
                  out = norm_layer(o)
                  if not self.channel_first:
                      out = out.permute(0, 3, 1, 2)
                  outs.append(out.contiguous())
          if len(self.out_indices) == 0:
              return x
          return outs
  1438. # =====================================================
  def vanilla_vmamba_tiny():
      """Build the original VMamba-Tiny: depths [2, 2, 9, 2], dims 96, drop-path 0.2, v0 forward."""
      config = dict(
          # layout
          depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2,
          patch_size=4, in_chans=3, num_classes=1000, imgsize=224, posembed=False,
          # SSM settings
          ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
          ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
          ssm_init="v0", forward_type="v0",
          # MLP settings
          mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
          # stem / downsampling / normalization
          patch_norm=True, norm_layer="ln",
          downsample_version="v1", patchembed_version="v1",
          use_checkpoint=False,
      )
      return VSSM(**config)
  def vanilla_vmamba_small():
      """Build the original VMamba-Small: depths [2, 2, 27, 2], dims 96, drop-path 0.3, v0 forward."""
      config = dict(
          # layout
          depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3,
          patch_size=4, in_chans=3, num_classes=1000, imgsize=224, posembed=False,
          # SSM settings
          ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
          ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
          ssm_init="v0", forward_type="v0",
          # MLP settings
          mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
          # stem / downsampling / normalization
          patch_norm=True, norm_layer="ln",
          downsample_version="v1", patchembed_version="v1",
          use_checkpoint=False,
      )
      return VSSM(**config)
  def vanilla_vmamba_base():
      """Build the original VMamba-Base: depths [2, 2, 27, 2], dims 128, drop-path 0.6, v0 forward."""
      config = dict(
          # layout
          depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6,
          patch_size=4, in_chans=3, num_classes=1000, imgsize=224, posembed=False,
          # SSM settings
          ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
          ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
          ssm_init="v0", forward_type="v0",
          # MLP settings
          mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
          # stem / downsampling / normalization
          patch_norm=True, norm_layer="ln",
          downsample_version="v1", patchembed_version="v1",
          use_checkpoint=False,
      )
      return VSSM(**config)
  1475. # =====================================================
  def vmamba_tiny_s2l5(channel_first=True):
      """Build VMamba-Tiny s2l5: depths [2, 2, 5, 2], dims 96, drop-path 0.2, ssm_ratio 2.0.

      ``channel_first`` selects ``"ln2d"`` (True) or ``"ln"`` (False) as ``norm_layer``.
      """
      norm_layer = "ln2d" if channel_first else "ln"
      return VSSM(
          depths=[2, 2, 5, 2],
          dims=96,
          drop_path_rate=0.2,
          patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
          ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto",
          ssm_act_layer="silu", ssm_conv=3, ssm_conv_bias=False,
          ssm_drop_rate=0.0, ssm_init="v0", forward_type="v05_noz",
          mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
          patch_norm=True,
          norm_layer=norm_layer,
          downsample_version="v3",
          patchembed_version="v2",
          use_checkpoint=False,
          posembed=False,
      )
  def vmamba_small_s2l15(channel_first=True):
      """Build VMamba-Small s2l15: depths [2, 2, 15, 2], dims 96, drop-path 0.3, ssm_ratio 2.0.

      ``channel_first`` selects ``"ln2d"`` (True) or ``"ln"`` (False) as ``norm_layer``.
      """
      norm_layer = "ln2d" if channel_first else "ln"
      return VSSM(
          depths=[2, 2, 15, 2],
          dims=96,
          drop_path_rate=0.3,
          patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
          ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto",
          ssm_act_layer="silu", ssm_conv=3, ssm_conv_bias=False,
          ssm_drop_rate=0.0, ssm_init="v0", forward_type="v05_noz",
          mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
          patch_norm=True,
          norm_layer=norm_layer,
          downsample_version="v3",
          patchembed_version="v2",
          use_checkpoint=False,
          posembed=False,
      )
  def vmamba_base_s2l15(channel_first=True):
      """Build VMamba-Base s2l15: depths [2, 2, 15, 2], dims 128, drop-path 0.6, ssm_ratio 2.0.

      ``channel_first`` selects ``"ln2d"`` (True) or ``"ln"`` (False) as ``norm_layer``.
      """
      norm_layer = "ln2d" if channel_first else "ln"
      return VSSM(
          depths=[2, 2, 15, 2],
          dims=128,
          drop_path_rate=0.6,
          patch_size=4, in_chans=3, num_classes=1000, imgsize=224,
          ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto",
          ssm_act_layer="silu", ssm_conv=3, ssm_conv_bias=False,
          ssm_drop_rate=0.0, ssm_init="v0", forward_type="v05_noz",
          mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
          patch_norm=True,
          norm_layer=norm_layer,
          downsample_version="v3",
          patchembed_version="v2",
          use_checkpoint=False,
          posembed=False,
      )
  1512. # =====================================================
  def vmamba_tiny_s1l8(channel_first=True):
      """Build VMamba-Tiny s1l8: depths [2, 2, 8, 2], dims 96, drop-path 0.2, ssm_ratio 1.0.

      ``channel_first`` selects ``"ln2d"`` (True) or ``"ln"`` (False) as ``norm_layer``.
      """
      arch = dict(depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2,
                  patch_size=4, in_chans=3, num_classes=1000, imgsize=224, posembed=False)
      ssm = dict(ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
                 ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
                 ssm_init="v0", forward_type="v05_noz")
      mlp = dict(mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False)
      return VSSM(
          **arch, **ssm, **mlp,
          patch_norm=True,
          norm_layer="ln2d" if channel_first else "ln",
          downsample_version="v3", patchembed_version="v2",
          use_checkpoint=False,
      )
  def vmamba_small_s1l20(channel_first=True):
      """Build VMamba-Small s1l20: depths [2, 2, 20, 2], dims 96, drop-path 0.3, ssm_ratio 1.0.

      ``channel_first`` selects ``"ln2d"`` (True) or ``"ln"`` (False) as ``norm_layer``.
      """
      arch = dict(depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3,
                  patch_size=4, in_chans=3, num_classes=1000, imgsize=224, posembed=False)
      ssm = dict(ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
                 ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
                 ssm_init="v0", forward_type="v05_noz")
      mlp = dict(mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False)
      return VSSM(
          **arch, **ssm, **mlp,
          patch_norm=True,
          norm_layer="ln2d" if channel_first else "ln",
          downsample_version="v3", patchembed_version="v2",
          use_checkpoint=False,
      )
  def vmamba_base_s1l20(channel_first=True):
      """Build VMamba-Base s1l20: depths [2, 2, 20, 2], dims 128, drop-path 0.5, ssm_ratio 1.0.

      ``channel_first`` selects ``"ln2d"`` (True) or ``"ln"`` (False) as ``norm_layer``.
      """
      arch = dict(depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5,
                  patch_size=4, in_chans=3, num_classes=1000, imgsize=224, posembed=False)
      ssm = dict(ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
                 ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
                 ssm_init="v0", forward_type="v05_noz")
      mlp = dict(mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False)
      return VSSM(
          **arch, **ssm, **mlp,
          patch_norm=True,
          norm_layer="ln2d" if channel_first else "ln",
          downsample_version="v3", patchembed_version="v2",
          use_checkpoint=False,
      )
  1549. # mamba2 support =====================================================
  # FLOPs counting does not work for the mamba2 models yet!
  def vmamba_tiny_m2():
      """Build the mamba2-style VMamba-Tiny: depths [2, 2, 4, 2], dims 96, drop-path 0.2, m0_noz forward."""
      kwargs = {
          "depths": [2, 2, 4, 2], "dims": 96, "drop_path_rate": 0.2,
          "patch_size": 4, "in_chans": 3, "num_classes": 1000,
          "ssm_d_state": 64, "ssm_ratio": 1.0, "ssm_dt_rank": "auto", "ssm_act_layer": "gelu",
          "ssm_conv": 3, "ssm_conv_bias": False, "ssm_drop_rate": 0.0,
          "ssm_init": "v2", "forward_type": "m0_noz",
          "mlp_ratio": 4.0, "mlp_act_layer": "gelu", "mlp_drop_rate": 0.0, "gmlp": False,
          "patch_norm": True, "norm_layer": "ln",
          "downsample_version": "v3", "patchembed_version": "v2",
          "use_checkpoint": False, "posembed": False, "imgsize": 224,
      }
      return VSSM(**kwargs)
  def vmamba_small_m2():
      """Build the mamba2-style VMamba-Small: depths [2, 2, 12, 2], dims 96, drop-path 0.3, m0_noz forward."""
      kwargs = {
          "depths": [2, 2, 12, 2], "dims": 96, "drop_path_rate": 0.3,
          "patch_size": 4, "in_chans": 3, "num_classes": 1000,
          "ssm_d_state": 64, "ssm_ratio": 1.0, "ssm_dt_rank": "auto", "ssm_act_layer": "gelu",
          "ssm_conv": 3, "ssm_conv_bias": False, "ssm_drop_rate": 0.0,
          "ssm_init": "v2", "forward_type": "m0_noz",
          "mlp_ratio": 4.0, "mlp_act_layer": "gelu", "mlp_drop_rate": 0.0, "gmlp": False,
          "patch_norm": True, "norm_layer": "ln",
          "downsample_version": "v3", "patchembed_version": "v2",
          "use_checkpoint": False, "posembed": False, "imgsize": 224,
      }
      return VSSM(**kwargs)
  def vmamba_base_m2():
      """Build the mamba2-style VMamba-Base: depths [2, 2, 12, 2], dims 128, drop-path 0.3, m0_noz forward."""
      kwargs = {
          "depths": [2, 2, 12, 2], "dims": 128, "drop_path_rate": 0.3,
          "patch_size": 4, "in_chans": 3, "num_classes": 1000,
          "ssm_d_state": 64, "ssm_ratio": 1.0, "ssm_dt_rank": "auto", "ssm_act_layer": "gelu",
          "ssm_conv": 3, "ssm_conv_bias": False, "ssm_drop_rate": 0.0,
          "ssm_init": "v2", "forward_type": "m0_noz",
          "mlp_ratio": 4.0, "mlp_act_layer": "gelu", "mlp_drop_rate": 0.0, "gmlp": False,
          "patch_norm": True, "norm_layer": "ln",
          "downsample_version": "v3", "patchembed_version": "v2",
          "use_checkpoint": False, "posembed": False, "imgsize": 224,
      }
      return VSSM(**kwargs)
  if __name__ == "__main__":
      import time

      # Compare the s1l8 reference model with the mamba2-style tiny model. The latter
      # uses exactly the same configuration as the former inline VSSM(...) call.
      model_ref = vmamba_tiny_s1l8()
      model = vmamba_tiny_m2()
      print(parameter_count(model)[""])
      # Known to be wrong: FLOPs counting does not support the mamba2 path yet.
      print(model.flops())
      model.cuda().train()
      model_ref.cuda().train()

      def _mean_step_time(step, iters=30):
          """Warm up with ``iters`` calls of ``step``, then return mean seconds per call over ``iters`` more."""
          for _ in range(iters):
              step()
          torch.cuda.synchronize()
          start = time.perf_counter()
          for _ in range(iters):
              step()
          # CUDA kernels run asynchronously; wait for them before reading the clock.
          torch.cuda.synchronize()
          return (time.perf_counter() - start) / iters

      def bench(net, iters=30):
          """Return (forward, forward+backward) mean seconds per iteration on a 128x3x224x224 batch."""
          inp = torch.randn((128, 3, 224, 224)).cuda()
          fwd = _mean_step_time(lambda: net(inp), iters)
          fwd_bwd = _mean_step_time(lambda: net(inp).sum().backward(), iters)
          return fwd, fwd_bwd

      print(bench(model_ref))
      print(bench(model))