vmamba.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849
  1. import os
  2. import time
  3. import math
  4. import copy
  5. from functools import partial
  6. from typing import Optional, Callable, Any
  7. from collections import OrderedDict
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import torch.utils.checkpoint as checkpoint
  12. from timm.layers import DropPath, trunc_normal_
  13. try:
  14. from .csm_triton import cross_scan_fn, cross_merge_fn
  15. except:
  16. from csm_triton import cross_scan_fn, cross_merge_fn
  17. try:
  18. from .csms6s import selective_scan_fn, selective_scan_flop_jit
  19. except:
  20. from csms6s import selective_scan_fn, selective_scan_flop_jit
  21. # FLOPs counter not prepared for mamba2.
  22. # Keep this dependency optional because SS2D(v2/v3) does not require it.
  23. try:
  24. from .mamba2.ssd_minimal import selective_scan_chunk_fn
  25. except Exception:
  26. try:
  27. from mamba2.ssd_minimal import selective_scan_chunk_fn
  28. except Exception:
  29. selective_scan_chunk_fn = None
  30. # =====================================================
# Linear and Conv2d initialise their weights differently, so we keep an nn.Linear subclass;
# its loader accepts weights saved from either a 1x1 Conv2d or a Linear layer.
  33. class Linear2d(nn.Linear):
  34. def forward(self, x: torch.Tensor):
  35. # B, C, H, W = x.shape
  36. return F.conv2d(x, self.weight[:, :, None, None], self.bias)
  37. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  38. state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view(self.weight.shape)
  39. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  40. class LayerNorm2d(nn.LayerNorm):
  41. def forward(self, x: torch.Tensor):
  42. x = x.permute(0, 2, 3, 1)
  43. x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  44. x = x.permute(0, 3, 1, 2)
  45. return x
  46. class PatchMerging2D(nn.Module):
  47. def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False):
  48. super().__init__()
  49. self.dim = dim
  50. Linear = Linear2d if channel_first else nn.Linear
  51. self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last
  52. self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
  53. self.norm = norm_layer(4 * dim)
  54. @staticmethod
  55. def _patch_merging_pad_channel_last(x: torch.Tensor):
  56. H, W, _ = x.shape[-3:]
  57. if (W % 2 != 0) or (H % 2 != 0):
  58. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  59. x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
  60. x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
  61. x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
  62. x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
  63. x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
  64. return x
  65. @staticmethod
  66. def _patch_merging_pad_channel_first(x: torch.Tensor):
  67. H, W = x.shape[-2:]
  68. if (W % 2 != 0) or (H % 2 != 0):
  69. x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
  70. x0 = x[..., 0::2, 0::2] # ... H/2 W/2
  71. x1 = x[..., 1::2, 0::2] # ... H/2 W/2
  72. x2 = x[..., 0::2, 1::2] # ... H/2 W/2
  73. x3 = x[..., 1::2, 1::2] # ... H/2 W/2
  74. x = torch.cat([x0, x1, x2, x3], 1) # ... H/2 W/2 4*C
  75. return x
  76. def forward(self, x):
  77. x = self._patch_merging_pad(x)
  78. x = self.norm(x)
  79. x = self.reduction(x)
  80. return x
  81. class Permute(nn.Module):
  82. def __init__(self, *args):
  83. super().__init__()
  84. self.args = args
  85. def forward(self, x: torch.Tensor):
  86. return x.permute(*self.args)
  87. class Mlp(nn.Module):
  88. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
  89. super().__init__()
  90. out_features = out_features or in_features
  91. hidden_features = hidden_features or in_features
  92. Linear = Linear2d if channels_first else nn.Linear
  93. self.fc1 = Linear(in_features, hidden_features)
  94. self.act = act_layer()
  95. self.fc2 = Linear(hidden_features, out_features)
  96. self.drop = nn.Dropout(drop)
  97. def forward(self, x):
  98. x = self.fc1(x)
  99. x = self.act(x)
  100. x = self.drop(x)
  101. x = self.fc2(x)
  102. x = self.drop(x)
  103. return x
  104. class gMlp(nn.Module):
  105. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
  106. super().__init__()
  107. self.channel_first = channels_first
  108. out_features = out_features or in_features
  109. hidden_features = hidden_features or in_features
  110. Linear = Linear2d if channels_first else nn.Linear
  111. self.fc1 = Linear(in_features, 2 * hidden_features)
  112. self.act = act_layer()
  113. self.fc2 = Linear(hidden_features, out_features)
  114. self.drop = nn.Dropout(drop)
  115. def forward(self, x: torch.Tensor):
  116. x = self.fc1(x)
  117. x, z = x.chunk(2, dim=(1 if self.channel_first else -1))
  118. x = self.fc2(x * self.act(z))
  119. x = self.drop(x)
  120. return x
  121. class SoftmaxSpatial(nn.Softmax):
  122. def forward(self, x: torch.Tensor):
  123. if self.dim == -1:
  124. B, C, H, W = x.shape
  125. return super().forward(x.view(B, C, -1)).view(B, C, H, W)
  126. elif self.dim == 1:
  127. B, H, W, C = x.shape
  128. return super().forward(x.view(B, -1, C)).view(B, H, W, C)
  129. else:
  130. raise NotImplementedError
  131. # =====================================================
  132. class mamba_init:
  133. @staticmethod
  134. def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
  135. dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
  136. # Initialize special dt projection to preserve variance at initialization
  137. dt_init_std = dt_rank**-0.5 * dt_scale
  138. if dt_init == "constant":
  139. nn.init.constant_(dt_proj.weight, dt_init_std)
  140. elif dt_init == "random":
  141. nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
  142. else:
  143. raise NotImplementedError
  144. # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
  145. dt = torch.exp(
  146. torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
  147. + math.log(dt_min)
  148. ).clamp(min=dt_init_floor)
  149. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  150. inv_dt = dt + torch.log(-torch.expm1(-dt))
  151. with torch.no_grad():
  152. dt_proj.bias.copy_(inv_dt)
  153. # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
  154. # dt_proj.bias._no_reinit = True
  155. return dt_proj
  156. @staticmethod
  157. def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
  158. # S4D real initialization
  159. A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
  160. A_log = torch.log(A) # Keep A_log in fp32
  161. if copies > 0:
  162. A_log = A_log[None].repeat(copies, 1, 1).contiguous()
  163. if merge:
  164. A_log = A_log.flatten(0, 1)
  165. A_log = nn.Parameter(A_log)
  166. A_log._no_weight_decay = True
  167. return A_log
  168. @staticmethod
  169. def D_init(d_inner, copies=-1, device=None, merge=True):
  170. # D "skip" parameter
  171. D = torch.ones(d_inner, device=device)
  172. if copies > 0:
  173. D = D[None].repeat(copies, 1).contiguous()
  174. if merge:
  175. D = D.flatten(0, 1)
  176. D = nn.Parameter(D) # Keep in fp32
  177. D._no_weight_decay = True
  178. return D
  179. @classmethod
  180. def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
  181. # dt proj ============================
  182. dt_projs = [
  183. cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
  184. for _ in range(k_group)
  185. ]
  186. dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K, inner, rank)
  187. dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K, inner)
  188. del dt_projs
  189. # A, D =======================================
  190. A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
  191. Ds = cls.D_init(d_inner, copies=k_group, merge=True) # (K * D)
  192. return A_logs, Ds, dt_projs_weight, dt_projs_bias
  193. # support: v0, v0seq
  194. class SS2Dv0:
  195. def __initv0__(
  196. self,
  197. # basic dims ===========
  198. d_model=96,
  199. d_state=16,
  200. ssm_ratio=2.0,
  201. dt_rank="auto",
  202. # ======================
  203. dropout=0.0,
  204. # ======================
  205. seq=False,
  206. force_fp32=True,
  207. **kwargs,
  208. ):
  209. if "channel_first" in kwargs:
  210. assert not kwargs["channel_first"]
  211. act_layer = nn.SiLU
  212. dt_min = 0.001
  213. dt_max = 0.1
  214. dt_init = "random"
  215. dt_scale = 1.0
  216. dt_init_floor = 1e-4
  217. bias = False
  218. conv_bias = True
  219. d_conv = 3
  220. k_group = 4
  221. factory_kwargs = {"device": None, "dtype": None}
  222. super().__init__()
  223. d_inner = int(ssm_ratio * d_model)
  224. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  225. self.forward = self.forwardv0
  226. if seq:
  227. self.forward = partial(self.forwardv0, seq=True)
  228. if not force_fp32:
  229. self.forward = partial(self.forwardv0, force_fp32=False)
  230. # in proj ============================
  231. self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
  232. self.act: nn.Module = act_layer()
  233. self.conv2d = nn.Conv2d(
  234. in_channels=d_inner,
  235. out_channels=d_inner,
  236. groups=d_inner,
  237. bias=conv_bias,
  238. kernel_size=d_conv,
  239. padding=(d_conv - 1) // 2,
  240. **factory_kwargs,
  241. )
  242. # x proj ============================
  243. self.x_proj = [
  244. nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
  245. for _ in range(k_group)
  246. ]
  247. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  248. del self.x_proj
  249. # dt proj, A, D ============================
  250. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  251. d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
  252. )
  253. # out proj =======================================
  254. self.out_norm = nn.LayerNorm(d_inner)
  255. self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
  256. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  257. def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
  258. x = self.in_proj(x)
  259. x, z = x.chunk(2, dim=-1) # (b, h, w, d)
  260. z = self.act(z)
  261. x = x.permute(0, 3, 1, 2).contiguous()
  262. x = self.conv2d(x) # (b, d, h, w)
  263. x = self.act(x)
  264. selective_scan = partial(selective_scan_fn, backend="mamba")
  265. B, D, H, W = x.shape
  266. D, N = self.A_logs.shape
  267. K, D, R = self.dt_projs_weight.shape
  268. L = H * W
  269. x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
  270. xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
  271. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
  272. if hasattr(self, "x_proj_bias"):
  273. x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
  274. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  275. dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
  276. xs = xs.view(B, -1, L) # (b, k * d, l)
  277. dts = dts.contiguous().view(B, -1, L) # (b, k * d, l)
  278. Bs = Bs.contiguous() # (b, k, d_state, l)
  279. Cs = Cs.contiguous() # (b, k, d_state, l)
  280. As = -self.A_logs.float().exp() # (k * d, d_state)
  281. Ds = self.Ds.float() # (k * d)
  282. dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
  283. # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
  284. # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
  285. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  286. if force_fp32:
  287. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  288. if seq:
  289. out_y = []
  290. for i in range(4):
  291. yi = selective_scan(
  292. xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
  293. As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
  294. delta_bias=dt_projs_bias.view(K, -1)[i],
  295. delta_softplus=True,
  296. ).view(B, -1, L)
  297. out_y.append(yi)
  298. out_y = torch.stack(out_y, dim=1)
  299. else:
  300. out_y = selective_scan(
  301. xs, dts,
  302. As, Bs, Cs, Ds,
  303. delta_bias=dt_projs_bias,
  304. delta_softplus=True,
  305. ).view(B, K, -1, L)
  306. assert out_y.dtype == torch.float
  307. inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
  308. wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  309. invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  310. y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
  311. y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
  312. y = self.out_norm(y).view(B, H, W, -1)
  313. y = y * z
  314. out = self.dropout(self.out_proj(y))
  315. return out
  316. # support: v01-v05; v051d,v052d,v052dc;
  317. # postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32;
  318. # history support: v2,v3;v31d,v32d,v32dc;
  319. class SS2Dv2:
  320. def __initv2__(
  321. self,
  322. # basic dims ===========
  323. d_model=96,
  324. d_state=16,
  325. ssm_ratio=2.0,
  326. dt_rank="auto",
  327. act_layer=nn.SiLU,
  328. # dwconv ===============
  329. d_conv=3, # < 2 means no conv
  330. conv_bias=True,
  331. # ======================
  332. dropout=0.0,
  333. bias=False,
  334. # dt init ==============
  335. dt_min=0.001,
  336. dt_max=0.1,
  337. dt_init="random",
  338. dt_scale=1.0,
  339. dt_init_floor=1e-4,
  340. initialize="v0",
  341. # ======================
  342. forward_type="v2",
  343. channel_first=False,
  344. # ======================
  345. **kwargs,
  346. ):
  347. factory_kwargs = {"device": None, "dtype": None}
  348. super().__init__()
  349. self.k_group = 4
  350. self.d_model = int(d_model)
  351. self.d_state = int(d_state)
  352. self.d_inner = int(ssm_ratio * d_model)
  353. self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
  354. self.channel_first = channel_first
  355. self.with_dconv = d_conv > 1
  356. Linear = Linear2d if channel_first else nn.Linear
  357. self.forward = self.forwardv2
  358. # tags for forward_type ==============================
  359. checkpostfix = self.checkpostfix
  360. self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
  361. self.oact, forward_type = checkpostfix("_oact", forward_type)
  362. self.disable_z, forward_type = checkpostfix("_noz", forward_type)
  363. self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
  364. self.out_norm, forward_type = self.get_outnorm(forward_type, self.d_inner, channel_first)
  365. # forward_type debug =======================================
  366. FORWARD_TYPES = dict(
  367. v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", scan_force_torch=True),
  368. v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"),
  369. v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"),
  370. v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d"
  371. v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), # selective_scan_backend="oflex", scan_mode="cross2d"
  372. # ===============================
  373. v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"),
  374. v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"),
  375. v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"),
  376. v052d3=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode=3), # debug
  377. # ===============================
  378. v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"),
  379. v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"),
  380. )
  381. self.forward_core = FORWARD_TYPES.get(forward_type, None)
  382. # in proj =======================================
  383. d_proj = self.d_inner if self.disable_z else (self.d_inner * 2)
  384. self.in_proj = Linear(self.d_model, d_proj, bias=bias)
  385. self.act: nn.Module = act_layer()
  386. # conv =======================================
  387. if self.with_dconv:
  388. self.conv2d = nn.Conv2d(
  389. in_channels=self.d_inner,
  390. out_channels=self.d_inner,
  391. groups=self.d_inner,
  392. bias=conv_bias,
  393. kernel_size=d_conv,
  394. padding=(d_conv - 1) // 2,
  395. **factory_kwargs,
  396. )
  397. # x proj ============================
  398. self.x_proj = [
  399. nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False)
  400. for _ in range(self.k_group)
  401. ]
  402. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  403. del self.x_proj
  404. # out proj =======================================
  405. self.out_act = nn.GELU() if self.oact else nn.Identity()
  406. self.out_proj = Linear(self.d_inner, self.d_model, bias=bias)
  407. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  408. if initialize in ["v0"]:
  409. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  410. self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
  411. )
  412. elif initialize in ["v1"]:
  413. # simple init dt_projs, A_logs, Ds
  414. self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
  415. self.A_logs = nn.Parameter(torch.randn((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  416. self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner, self.dt_rank))) # 0.1 is added in 0430
  417. self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner))) # 0.1 is added in 0430
  418. elif initialize in ["v2"]:
  419. # simple init dt_projs, A_logs, Ds
  420. self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
  421. self.A_logs = nn.Parameter(torch.zeros((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  422. self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner, self.dt_rank)))
  423. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner)))
  424. def forward_corev2(
  425. self,
  426. x: torch.Tensor=None,
  427. # ==============================
  428. force_fp32=False, # True: input fp32
  429. # ==============================
  430. ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input
  431. no_einsum=False, # replace einsum with linear or conv1d to raise throughput
  432. # ==============================
  433. selective_scan_backend = None,
  434. # ==============================
  435. scan_mode = "cross2d",
  436. scan_force_torch = False,
  437. # ==============================
  438. **kwargs,
  439. ):
  440. assert selective_scan_backend in [None, "oflex", "mamba", "torch"]
  441. _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode # for debug
  442. assert isinstance(_scan_mode, int)
  443. delta_softplus = True
  444. out_norm = self.out_norm
  445. channel_first = self.channel_first
  446. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  447. B, D, H, W = x.shape
  448. N = self.d_state
  449. K, D, R = self.k_group, self.d_inner, self.dt_rank
  450. L = H * W
  451. def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
  452. return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend)
  453. if _scan_mode == -1:
  454. x_proj_bias = getattr(self, "x_proj_bias", None)
  455. def scan_rowcol(
  456. x: torch.Tensor,
  457. proj_weight: torch.Tensor,
  458. proj_bias: torch.Tensor,
  459. dt_weight: torch.Tensor,
  460. dt_bias: torch.Tensor, # (2*c)
  461. _As: torch.Tensor, # As = -torch.exp(A_logs.to(torch.float))[:2,] # (2*c, d_state)
  462. _Ds: torch.Tensor,
  463. width = True,
  464. ):
  465. # x: (B, D, H, W)
  466. # proj_weight: (2 * D, (R+N+N))
  467. XB, XD, XH, XW = x.shape
  468. if width:
  469. _B, _D, _L = XB * XH, XD, XW
  470. xs = x.permute(0, 2, 1, 3).contiguous()
  471. else:
  472. _B, _D, _L = XB * XW, XD, XH
  473. xs = x.permute(0, 3, 1, 2).contiguous()
  474. xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2) # (B, H, 2, D, W)
  475. if no_einsum:
  476. x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2)
  477. dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2)
  478. dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2)
  479. else:
  480. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight)
  481. if x_proj_bias is not None:
  482. x_dbl = x_dbl + x_proj_bias.view(1, 2, -1, 1)
  483. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  484. dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_weight)
  485. xs = xs.view(_B, -1, _L)
  486. dts = dts.contiguous().view(_B, -1, _L)
  487. As = _As.view(-1, N).to(torch.float)
  488. Bs = Bs.contiguous().view(_B, 2, N, _L)
  489. Cs = Cs.contiguous().view(_B, 2, N, _L)
  490. Ds = _Ds.view(-1)
  491. delta_bias = dt_bias.view(-1).to(torch.float)
  492. if force_fp32:
  493. xs = xs.to(torch.float)
  494. dts = dts.to(xs.dtype)
  495. Bs = Bs.to(xs.dtype)
  496. Cs = Cs.to(xs.dtype)
  497. ys: torch.Tensor = selective_scan(
  498. xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  499. ).view(_B, 2, -1, _L)
  500. return ys
  501. As = -self.A_logs.to(torch.float).exp().view(4, -1, N)
  502. x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
  503. y_row = scan_rowcol(
  504. x,
  505. proj_weight = self.x_proj_weight.view(4, -1, D)[:2].contiguous(),
  506. proj_bias = (x_proj_bias.view(4, -1)[:2].contiguous() if x_proj_bias is not None else None),
  507. dt_weight = self.dt_projs_weight.view(4, D, -1)[:2].contiguous(),
  508. dt_bias = (self.dt_projs_bias.view(4, -1)[:2].contiguous() if self.dt_projs_bias is not None else None),
  509. _As = As[:2].contiguous().view(-1, N),
  510. _Ds = self.Ds.view(4, -1)[:2].contiguous().view(-1),
  511. width=True,
  512. ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3) # (B,C,H,W)
  513. y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
  514. y_col = scan_rowcol(
  515. y_row,
  516. proj_weight = self.x_proj_weight.view(4, -1, D)[2:].contiguous().to(y_row.dtype),
  517. proj_bias = (x_proj_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if x_proj_bias is not None else None),
  518. dt_weight = self.dt_projs_weight.view(4, D, -1)[2:].contiguous().to(y_row.dtype),
  519. dt_bias = (self.dt_projs_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if self.dt_projs_bias is not None else None),
  520. _As = As[2:].contiguous().view(-1, N),
  521. _Ds = self.Ds.view(4, -1)[2:].contiguous().view(-1),
  522. width=False,
  523. ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1)
  524. y = y_col
  525. else:
  526. x_proj_bias = getattr(self, "x_proj_bias", None)
  527. xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  528. if no_einsum:
  529. x_dbl = F.conv1d(xs.view(B, -1, L), self.x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
  530. dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
  531. if hasattr(self, "dt_projs_weight"):
  532. dts = F.conv1d(dts.contiguous().view(B, -1, L), self.dt_projs_weight.view(K * D, -1, 1), groups=K)
  533. else:
  534. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
  535. if x_proj_bias is not None:
  536. x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
  537. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  538. if hasattr(self, "dt_projs_weight"):
  539. dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
  540. xs = xs.view(B, -1, L)
  541. dts = dts.contiguous().view(B, -1, L)
  542. As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
  543. Ds = self.Ds.to(torch.float) # (K * c)
  544. Bs = Bs.contiguous().view(B, K, N, L)
  545. Cs = Cs.contiguous().view(B, K, N, L)
  546. delta_bias = self.dt_projs_bias.view(-1).to(torch.float)
  547. if force_fp32:
  548. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  549. ys: torch.Tensor = selective_scan(
  550. xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  551. ).view(B, K, -1, H, W)
  552. y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  553. if getattr(self, "__DEBUG__", False):
  554. setattr(self, "__data__", dict(
  555. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
  556. us=xs, dts=dts, delta_bias=delta_bias,
  557. ys=ys, y=y, H=H, W=W,
  558. ))
  559. y = y.view(B, -1, H, W)
  560. if not channel_first:
  561. y = y.view(B, -1, H * W).transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) # (B, L, C)
  562. y = out_norm(y)
  563. return y.to(x.dtype)
  564. def forwardv2(self, x: torch.Tensor, **kwargs):
  565. x = self.in_proj(x)
  566. if not self.disable_z:
  567. x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
  568. if not self.disable_z_act:
  569. z = self.act(z)
  570. if not self.channel_first:
  571. x = x.permute(0, 3, 1, 2).contiguous()
  572. if self.with_dconv:
  573. x = self.conv2d(x) # (b, d, h, w)
  574. x = self.act(x)
  575. y = self.forward_core(x)
  576. y = self.out_act(y)
  577. if not self.disable_z:
  578. y = y * z
  579. out = self.dropout(self.out_proj(y))
  580. return out
  581. @staticmethod
  582. def get_outnorm(forward_type="", d_inner=192, channel_first=True):
  583. def checkpostfix(tag, value):
  584. ret = value[-len(tag):] == tag
  585. if ret:
  586. value = value[:-len(tag)]
  587. return ret, value
  588. LayerNorm = LayerNorm2d if channel_first else nn.LayerNorm
  589. out_norm_none, forward_type = checkpostfix("_onnone", forward_type)
  590. out_norm_dwconv3, forward_type = checkpostfix("_ondwconv3", forward_type)
  591. out_norm_cnorm, forward_type = checkpostfix("_oncnorm", forward_type)
  592. out_norm_softmax, forward_type = checkpostfix("_onsoftmax", forward_type)
  593. out_norm_sigmoid, forward_type = checkpostfix("_onsigmoid", forward_type)
  594. out_norm = nn.Identity()
  595. if out_norm_none:
  596. out_norm = nn.Identity()
  597. elif out_norm_cnorm:
  598. out_norm = nn.Sequential(
  599. LayerNorm(d_inner),
  600. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  601. nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
  602. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  603. )
  604. elif out_norm_dwconv3:
  605. out_norm = nn.Sequential(
  606. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  607. nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
  608. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  609. )
  610. elif out_norm_softmax:
  611. out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1))
  612. elif out_norm_sigmoid:
  613. out_norm = nn.Sigmoid()
  614. else:
  615. out_norm = LayerNorm(d_inner)
  616. return out_norm, forward_type
  617. @staticmethod
  618. def checkpostfix(tag, value):
  619. ret = value[-len(tag):] == tag
  620. if ret:
  621. value = value[:-len(tag)]
  622. return ret, value
  623. # support: xv1a,xv2a,xv3a;
  624. # postfix: _cpos;_ocov;_ocov2;_ca,_ca1;_act;_mul;_onsigmoid,_onsoftmax,_ondwconv3,_onnone;
  625. class SS2Dv3:
  626. def __initxv__(
  627. self,
  628. # basic dims ===========
  629. d_model=96,
  630. d_state=16,
  631. ssm_ratio=2.0,
  632. dt_rank="auto",
  633. # dwconv ===============
  634. d_conv=3, # < 2 means no conv
  635. conv_bias=True,
  636. # ======================
  637. dropout=0.0,
  638. bias=False,
  639. # dt init ==============
  640. dt_min=0.001,
  641. dt_max=0.1,
  642. dt_init="random",
  643. dt_scale=1.0,
  644. dt_init_floor=1e-4,
  645. initialize="v0",
  646. # ======================
  647. forward_type="v2",
  648. channel_first=False,
  649. # ======================
  650. **kwargs,
  651. ):
  652. super().__init__()
  653. d_inner = int(ssm_ratio * d_model)
  654. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  655. self.channel_first = channel_first
  656. self.d_state = d_state
  657. self.dt_rank = dt_rank
  658. self.d_inner = d_inner
  659. k_group = 4
  660. self.with_dconv = d_conv > 1
  661. Linear = Linear2d if channel_first else nn.Linear
  662. self.forward = self.forwardxv
  663. # tags for forward_type ==============================
  664. checkpostfix = SS2Dv2.checkpostfix
  665. self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, channel_first)
  666. self.omul, forward_type = checkpostfix("_mul", forward_type)
  667. self.oact, forward_type = checkpostfix("_act", forward_type)
  668. self.f_omul = nn.Identity() if self.omul else None
  669. self.out_act = nn.GELU() if self.oact else nn.Identity()
  670. mode = forward_type[:4]
  671. assert mode in ["xv1a", "xv2a", "xv3a"]
  672. self.forward = partial(self.forwardxv, mode=mode)
  673. self.dts_dim = dict(xv1a=self.dt_rank, xv2a=self.d_inner, xv3a=4 * self.dt_rank)[mode]
  674. d_inner_all = d_inner + self.dts_dim + 8 * d_state
  675. self.in_proj = Linear(d_model, d_inner_all, bias=bias)
  676. # conv =======================================
  677. self.cpos = False
  678. self.iconv = False
  679. self.oconv = False
  680. self.oconv2 = False
  681. if self.with_dconv:
  682. cact, forward_type = checkpostfix("_ca", forward_type)
  683. cact1, forward_type = checkpostfix("_ca1", forward_type)
  684. self.cact = nn.SiLU() if cact else nn.Identity()
  685. self.cact = nn.GELU() if cact1 else self.cact
  686. self.oconv2, forward_type = checkpostfix("_ocov2", forward_type)
  687. self.oconv, forward_type = checkpostfix("_ocov", forward_type)
  688. self.cpos, forward_type = checkpostfix("_cpos", forward_type)
  689. self.iconv = (not self.oconv) and (not self.oconv2)
  690. if self.iconv:
  691. self.conv2d = nn.Conv2d(
  692. in_channels=d_model,
  693. out_channels=d_model,
  694. groups=d_model,
  695. bias=conv_bias,
  696. kernel_size=d_conv,
  697. padding=(d_conv - 1) // 2,
  698. )
  699. if self.oconv:
  700. self.oconv2d = nn.Conv2d(
  701. in_channels=d_inner,
  702. out_channels=d_inner,
  703. groups=d_inner,
  704. bias=conv_bias,
  705. kernel_size=d_conv,
  706. padding=(d_conv - 1) // 2,
  707. )
  708. if self.oconv2:
  709. self.conv2d = nn.Conv2d(
  710. in_channels=d_inner_all,
  711. out_channels=d_inner_all,
  712. groups=d_inner_all,
  713. bias=conv_bias,
  714. kernel_size=d_conv,
  715. padding=(d_conv - 1) // 2,
  716. )
  717. # out proj =======================================
  718. self.out_proj = Linear(d_inner, d_model, bias=bias)
  719. self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
  720. if initialize in ["v0"]:
  721. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  722. d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
  723. )
  724. elif initialize in ["v1"]:
  725. # simple init dt_projs, A_logs, Ds
  726. self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
  727. self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  728. self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
  729. self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
  730. elif initialize in ["v2"]:
  731. # simple init dt_projs, A_logs, Ds
  732. self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
  733. self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  734. self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank)))
  735. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner)))
  736. if forward_type.startswith("xv2"):
  737. del self.dt_projs_weight
  738. self.dt_projs_weight = None
  739. def forwardxv(self, x: torch.Tensor, **kwargs):
  740. B, (H, W) = x.shape[0], (x.shape[2:4] if self.channel_first else x.shape[1:3])
  741. L = H * W
  742. force_fp32 = False
  743. delta_softplus = True
  744. out_norm = self.out_norm
  745. to_dtype = True
  746. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  747. def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):
  748. return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex=True, backend=None)
  749. if self.iconv:
  750. x = self.cact(self.conv2d(x)) # (b, d, h, w)
  751. elif self.cpos:
  752. x = x + self.conv2d(x) # (b, d, h, w)
  753. x = self.in_proj(x)
  754. if self.oconv2:
  755. x = self.conv2d(x) # (b, d, h, w)
  756. us, dts, Bs, Cs = x.split([self.d_inner, self.dts_dim, 4 * self.d_state, 4 * self.d_state], dim=(1 if self.channel_first else -1))
  757. _us = us
  758. # Bs, Cs = Bs.view(B, H, W, 4, -1), Cs.view(B, H, W, 4, -1)
  759. # Bs, Cs = Bs.view(B, 4, -1, H, W), Cs.view(B, 4, -1, H, W)
  760. us = cross_scan_fn(us.contiguous(), in_channel_first=self.channel_first, out_channel_first=True).view(B, -1, L)
  761. Bs = cross_scan_fn(Bs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
  762. Cs = cross_scan_fn(Cs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
  763. dts = cross_scan_fn(dts.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=(self.dts_dim == 4 * self.dt_rank)).view(B, L, -1)
  764. if self.dts_dim == self.dt_rank:
  765. dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)
  766. elif self.dts_dim == 4 * self.dt_rank:
  767. dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)
  768. As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
  769. Ds = self.Ds.to(torch.float) # (K * c)
  770. delta_bias = self.dt_projs_bias.view(-1).to(torch.float) # (K * c)
  771. if force_fp32:
  772. us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs)
  773. ys: torch.Tensor = selective_scan(
  774. us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  775. ).view(B, 4, -1, H, W)
  776. y: torch.Tensor = cross_merge_fn(ys.contiguous(), in_channel_first=self.channel_first, out_channel_first=True)
  777. y = y.view(B, -1, H, W) if self.channel_first else y.view(B, H, W, -1)
  778. y = out_norm(y)
  779. if getattr(self, "__DEBUG__", False):
  780. setattr(self, "__data__", dict(
  781. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
  782. us=us, dts=dts, delta_bias=delta_bias,
  783. ys=ys, y=y,
  784. ))
  785. y = (y.to(x.dtype) if to_dtype else y)
  786. y = self.out_act(y)
  787. if self.omul:
  788. y = y * _us
  789. if self.oconv:
  790. y = y + self.cact(self.oconv2d(_us))
  791. out = self.dropout(self.out_proj(y))
  792. return out
  793. # mamba2 support ================================
  794. class SS2Dm0:
  795. def __initm0__(
  796. self,
  797. # basic dims ===========
  798. d_model=96,
  799. d_state=16, # now with mamba2, dstate should be bigger...
  800. ssm_ratio=2.0,
  801. dt_rank="auto",
  802. act_layer=nn.GELU,
  803. # dwconv ===============
  804. d_conv=3, # < 2 means no conv
  805. conv_bias=True,
  806. # ======================
  807. dropout=0.0,
  808. bias=False,
  809. # dt init ==============
  810. dt_min=0.001,
  811. dt_max=0.1,
  812. dt_init="random",
  813. dt_scale=1.0,
  814. dt_init_floor=1e-4,
  815. initialize="v2",
  816. # ======================
  817. forward_type="m0",
  818. # ======================
  819. with_initial_state=False,
  820. # ======================
  821. **kwargs,
  822. ):
  823. factory_kwargs = {"device": None, "dtype": None}
  824. super().__init__()
  825. d_inner = int(ssm_ratio * d_model)
  826. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  827. assert d_inner % dt_rank == 0
  828. self.with_dconv = d_conv > 1
  829. Linear = nn.Linear
  830. self.forward = self.forwardm0
  831. # tags for forward_type ==============================
  832. checkpostfix = SS2Dv2.checkpostfix
  833. self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
  834. self.oact, forward_type = checkpostfix("_oact", forward_type)
  835. self.disable_z, forward_type = checkpostfix("_noz", forward_type)
  836. self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
  837. self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False)
  838. # forward_type debug =======================================
  839. FORWARD_TYPES = dict(
  840. m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state),
  841. )
  842. self.forward_core = FORWARD_TYPES.get(forward_type, None)
  843. k_group = 4
  844. # in proj =======================================
  845. d_proj = d_inner if self.disable_z else (d_inner * 2)
  846. self.in_proj = Linear(d_model, d_proj, bias=bias)
  847. self.act: nn.Module = act_layer()
  848. # conv =======================================
  849. if self.with_dconv:
  850. self.conv2d = nn.Sequential(
  851. Permute(0, 3, 1, 2),
  852. nn.Conv2d(
  853. in_channels=d_inner,
  854. out_channels=d_inner,
  855. groups=d_inner,
  856. bias=conv_bias,
  857. kernel_size=d_conv,
  858. padding=(d_conv - 1) // 2,
  859. **factory_kwargs,
  860. ),
  861. Permute(0, 2, 3, 1),
  862. )
  863. # x proj ============================
  864. self.x_proj = [
  865. nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
  866. for _ in range(k_group)
  867. ]
  868. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  869. del self.x_proj
  870. # out proj =======================================
  871. self.out_act = nn.GELU() if self.oact else nn.Identity()
  872. self.out_proj = Linear(d_inner, d_model, bias=bias)
  873. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  874. if initialize in ["v1"]:
  875. # simple init dt_projs, A_logs, Ds
  876. self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
  877. self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  878. self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank))) # 0.1 is added in 0430
  879. elif initialize in ["v2"]:
  880. # simple init dt_projs, A_logs, Ds
  881. self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
  882. self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  883. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank)))
  884. # init state ============================
  885. self.initial_state = None
  886. if with_initial_state:
  887. self.initial_state = nn.Parameter(torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), requires_grad=False)
  888. def forward_corem0(
  889. self,
  890. x: torch.Tensor=None,
  891. # ==============================
  892. force_fp32=False, # True: input fp32
  893. chunk_size = 64,
  894. dstate = 64,
  895. # ==============================
  896. selective_scan_backend = None,
  897. scan_mode = "cross2d",
  898. scan_force_torch = False,
  899. # ==============================
  900. **kwargs,
  901. ):
  902. assert scan_mode in ["unidi", "bidi", "cross2d"]
  903. assert selective_scan_backend in [None, "triton", "torch"]
  904. x_proj_bias = getattr(self, "x_proj_bias", None)
  905. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  906. N = dstate
  907. B, H, W, RD = x.shape
  908. K, R = self.A_logs.shape
  909. K, R, D = self.Ds.shape
  910. assert RD == R * D
  911. L = H * W
  912. KR = K * R
  913. _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode]
  914. initial_state = None
  915. if self.initial_state is not None:
  916. assert self.initial_state.shape[-1] == dstate
  917. initial_state = self.initial_state.detach().repeat(B, 1, 1, 1)
  918. xs = cross_scan_fn(x.view(B, H, W, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch) # (B, H, W, 4, D)
  919. x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight)
  920. if x_proj_bias is not None:
  921. x_dbl = x_dbl + x_proj_bias.view(1, -1, K, 1)
  922. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3)
  923. xs = xs.contiguous().view(B, L, KR, D)
  924. dts = dts.contiguous().view(B, L, KR)
  925. Bs = Bs.contiguous().view(B, L, K, N)
  926. Cs = Cs.contiguous().view(B, L, K, N)
  927. if force_fp32:
  928. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  929. As = -self.A_logs.to(torch.float).exp().view(KR)
  930. Ds = self.Ds.to(torch.float).view(KR, D)
  931. dt_bias = self.dt_projs_bias.view(KR)
  932. if force_fp32:
  933. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  934. ys, final_state = selective_scan_chunk_fn(
  935. xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias,
  936. initial_states=initial_state, dt_softplus=True, return_final_states=True,
  937. backend=selective_scan_backend,
  938. )
  939. y: torch.Tensor = cross_merge_fn(ys.view(B, H, W, K, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)
  940. if getattr(self, "__DEBUG__", False):
  941. setattr(self, "__data__", dict(
  942. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds,
  943. us=xs, dts=dts, delta_bias=self.dt_projs_bias,
  944. initial_state=self.initial_state, final_satte=final_state,
  945. ys=ys, y=y, H=H, W=W,
  946. ))
  947. if self.initial_state is not None:
  948. self.initial_state = nn.Parameter(final_state.detach().sum(0, keepdim=True), requires_grad=False)
  949. y = self.out_norm(y.view(B, H, W, -1))
  950. return y.to(x.dtype)
  951. def forwardm0(self, x: torch.Tensor, **kwargs):
  952. x = self.in_proj(x)
  953. if not self.disable_z:
  954. x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
  955. if not self.disable_z_act:
  956. z = self.act(z)
  957. if self.with_dconv:
  958. x = self.conv2d(x) # (b, d, h, w)
  959. x = self.act(x)
  960. y = self.forward_core(x)
  961. y = self.out_act(y)
  962. if not self.disable_z:
  963. y = y * z
  964. out = self.dropout(self.out_proj(y))
  965. return out
  966. class SS2D(nn.Module, SS2Dv0, SS2Dv2, SS2Dv3, SS2Dm0):
  967. def __init__(
  968. self,
  969. # basic dims ===========
  970. d_model=96,
  971. d_state=16,
  972. ssm_ratio=2.0,
  973. dt_rank="auto",
  974. act_layer=nn.SiLU,
  975. # dwconv ===============
  976. d_conv=3, # < 2 means no conv
  977. conv_bias=True,
  978. # ======================
  979. dropout=0.0,
  980. bias=False,
  981. # dt init ==============
  982. dt_min=0.001,
  983. dt_max=0.1,
  984. dt_init="random",
  985. dt_scale=1.0,
  986. dt_init_floor=1e-4,
  987. initialize="v0",
  988. # ======================
  989. forward_type="v2",
  990. channel_first=False,
  991. # ======================
  992. **kwargs,
  993. ):
  994. nn.Module.__init__(self)
  995. kwargs.update(
  996. d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank,
  997. act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias,
  998. dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor,
  999. initialize=initialize, forward_type=forward_type, channel_first=channel_first,
  1000. )
  1001. if forward_type in ["v0", "v0seq"]:
  1002. self.__initv0__(seq=("seq" in forward_type), **kwargs)
  1003. elif forward_type.startswith("xv"):
  1004. self.__initxv__(**kwargs)
  1005. elif forward_type.startswith("m"):
  1006. self.__initm0__(**kwargs)
  1007. else:
  1008. self.__initv2__(**kwargs)
  1009. # =====================================================
  1010. class VSSBlock(nn.Module):
  1011. def __init__(
  1012. self,
  1013. hidden_dim: int = 0,
  1014. drop_path: float = 0,
  1015. norm_layer: nn.Module = nn.LayerNorm,
  1016. channel_first=False,
  1017. # =============================
  1018. ssm_d_state: int = 16,
  1019. ssm_ratio=2.0,
  1020. ssm_dt_rank: Any = "auto",
  1021. ssm_act_layer=nn.SiLU,
  1022. ssm_conv: int = 3,
  1023. ssm_conv_bias=True,
  1024. ssm_drop_rate: float = 0,
  1025. ssm_init="v0",
  1026. forward_type="v2",
  1027. # =============================
  1028. mlp_ratio=4.0,
  1029. mlp_act_layer=nn.GELU,
  1030. mlp_drop_rate: float = 0.0,
  1031. gmlp=False,
  1032. # =============================
  1033. use_checkpoint: bool = False,
  1034. post_norm: bool = False,
  1035. # =============================
  1036. _SS2D: type = SS2D,
  1037. **kwargs,
  1038. ):
  1039. super().__init__()
  1040. self.ssm_branch = ssm_ratio > 0
  1041. self.mlp_branch = mlp_ratio > 0
  1042. self.use_checkpoint = use_checkpoint
  1043. self.post_norm = post_norm
  1044. if self.ssm_branch:
  1045. self.norm = norm_layer(hidden_dim)
  1046. self.op = _SS2D(
  1047. d_model=hidden_dim,
  1048. d_state=ssm_d_state,
  1049. ssm_ratio=ssm_ratio,
  1050. dt_rank=ssm_dt_rank,
  1051. act_layer=ssm_act_layer,
  1052. # ==========================
  1053. d_conv=ssm_conv,
  1054. conv_bias=ssm_conv_bias,
  1055. # ==========================
  1056. dropout=ssm_drop_rate,
  1057. # bias=False,
  1058. # ==========================
  1059. # dt_min=0.001,
  1060. # dt_max=0.1,
  1061. # dt_init="random",
  1062. # dt_scale="random",
  1063. # dt_init_floor=1e-4,
  1064. initialize=ssm_init,
  1065. # ==========================
  1066. forward_type=forward_type,
  1067. channel_first=channel_first,
  1068. )
  1069. self.drop_path = DropPath(drop_path)
  1070. if self.mlp_branch:
  1071. _MLP = Mlp if not gmlp else gMlp
  1072. self.norm2 = norm_layer(hidden_dim)
  1073. mlp_hidden_dim = int(hidden_dim * mlp_ratio)
  1074. self.mlp = _MLP(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=channel_first)
  1075. def _forward(self, input: torch.Tensor):
  1076. x = input
  1077. if self.ssm_branch:
  1078. if self.post_norm:
  1079. x = x + self.drop_path(self.norm(self.op(x)))
  1080. else:
  1081. x = x + self.drop_path(self.op(self.norm(x)))
  1082. if self.mlp_branch:
  1083. if self.post_norm:
  1084. x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
  1085. else:
  1086. x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
  1087. return x
  1088. def forward(self, input: torch.Tensor):
  1089. if self.use_checkpoint:
  1090. return checkpoint.checkpoint(self._forward, input)
  1091. else:
  1092. return self._forward(input)
  1093. class VSSM(nn.Module):
  1094. def __init__(
  1095. self,
  1096. patch_size=4,
  1097. in_chans=3,
  1098. num_classes=1000,
  1099. depths=[2, 2, 9, 2],
  1100. dims=[96, 192, 384, 768],
  1101. # =========================
  1102. ssm_d_state=16,
  1103. ssm_ratio=2.0,
  1104. ssm_dt_rank="auto",
  1105. ssm_act_layer="silu",
  1106. ssm_conv=3,
  1107. ssm_conv_bias=True,
  1108. ssm_drop_rate=0.0,
  1109. ssm_init="v0",
  1110. forward_type="v2",
  1111. # =========================
  1112. mlp_ratio=4.0,
  1113. mlp_act_layer="gelu",
  1114. mlp_drop_rate=0.0,
  1115. gmlp=False,
  1116. # =========================
  1117. drop_path_rate=0.1,
  1118. patch_norm=True,
  1119. norm_layer="LN", # "BN", "LN2D"
  1120. downsample_version: str = "v2", # "v1", "v2", "v3"
  1121. patchembed_version: str = "v1", # "v1", "v2"
  1122. use_checkpoint=False,
  1123. # =========================
  1124. posembed=False,
  1125. imgsize=224,
  1126. _SS2D=SS2D,
  1127. # =========================
  1128. **kwargs,
  1129. ):
  1130. super().__init__()
  1131. self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
  1132. self.num_classes = num_classes
  1133. self.num_layers = len(depths)
  1134. if isinstance(dims, int):
  1135. dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
  1136. self.num_features = dims[-1]
  1137. self.dims = dims
  1138. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  1139. _NORMLAYERS = dict(
  1140. ln=nn.LayerNorm,
  1141. ln2d=LayerNorm2d,
  1142. bn=nn.BatchNorm2d,
  1143. )
  1144. _ACTLAYERS = dict(
  1145. silu=nn.SiLU,
  1146. gelu=nn.GELU,
  1147. relu=nn.ReLU,
  1148. sigmoid=nn.Sigmoid,
  1149. )
  1150. norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)
  1151. ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
  1152. mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)
  1153. self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None
  1154. _make_patch_embed = dict(
  1155. v1=self._make_patch_embed,
  1156. v2=self._make_patch_embed_v2,
  1157. ).get(patchembed_version, None)
  1158. self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first)
  1159. _make_downsample = dict(
  1160. v1=PatchMerging2D,
  1161. v2=self._make_downsample,
  1162. v3=self._make_downsample_v3,
  1163. none=(lambda *_, **_k: None),
  1164. ).get(downsample_version, None)
  1165. self.layers = nn.ModuleList()
  1166. for i_layer in range(self.num_layers):
  1167. downsample = _make_downsample(
  1168. self.dims[i_layer],
  1169. self.dims[i_layer + 1],
  1170. norm_layer=norm_layer,
  1171. channel_first=self.channel_first,
  1172. ) if (i_layer < self.num_layers - 1) else nn.Identity()
  1173. self.layers.append(self._make_layer(
  1174. dim = self.dims[i_layer],
  1175. drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  1176. use_checkpoint=use_checkpoint,
  1177. norm_layer=norm_layer,
  1178. downsample=downsample,
  1179. channel_first=self.channel_first,
  1180. # =================
  1181. ssm_d_state=ssm_d_state,
  1182. ssm_ratio=ssm_ratio,
  1183. ssm_dt_rank=ssm_dt_rank,
  1184. ssm_act_layer=ssm_act_layer,
  1185. ssm_conv=ssm_conv,
  1186. ssm_conv_bias=ssm_conv_bias,
  1187. ssm_drop_rate=ssm_drop_rate,
  1188. ssm_init=ssm_init,
  1189. forward_type=forward_type,
  1190. # =================
  1191. mlp_ratio=mlp_ratio,
  1192. mlp_act_layer=mlp_act_layer,
  1193. mlp_drop_rate=mlp_drop_rate,
  1194. gmlp=gmlp,
  1195. # =================
  1196. _SS2D=_SS2D,
  1197. ))
  1198. self.classifier = nn.Sequential(OrderedDict(
  1199. norm=norm_layer(self.num_features), # B,H,W,C
  1200. permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
  1201. avgpool=nn.AdaptiveAvgPool2d(1),
  1202. flatten=nn.Flatten(1),
  1203. head=nn.Linear(self.num_features, num_classes),
  1204. ))
  1205. self.apply(self._init_weights)
  1206. @staticmethod
  1207. def _pos_embed(embed_dims, patch_size, img_size):
  1208. patch_height, patch_width = (img_size // patch_size, img_size // patch_size)
  1209. pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width))
  1210. trunc_normal_(pos_embed, std=0.02)
  1211. return pos_embed
  1212. def _init_weights(self, m: nn.Module):
  1213. if isinstance(m, nn.Linear):
  1214. trunc_normal_(m.weight, std=.02)
  1215. if isinstance(m, nn.Linear) and m.bias is not None:
  1216. nn.init.constant_(m.bias, 0)
  1217. elif isinstance(m, nn.LayerNorm):
  1218. nn.init.constant_(m.bias, 0)
  1219. nn.init.constant_(m.weight, 1.0)
  1220. # used in building optimizer
  1221. @torch.jit.ignore
  1222. def no_weight_decay(self):
  1223. return {"pos_embed"}
  1224. # used in building optimizer
  1225. @torch.jit.ignore
  1226. def no_weight_decay_keywords(self):
  1227. return {}
  1228. @staticmethod
  1229. def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
  1230. # if channel first, then Norm and Output are both channel_first
  1231. return nn.Sequential(
  1232. nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
  1233. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1234. (norm_layer(embed_dim) if patch_norm else nn.Identity()),
  1235. )
  1236. @staticmethod
  1237. def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
  1238. # if channel first, then Norm and Output are both channel_first
  1239. stride = patch_size // 2
  1240. kernel_size = stride + 1
  1241. padding = 1
  1242. return nn.Sequential(
  1243. nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
  1244. (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)),
  1245. (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
  1246. (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)),
  1247. nn.GELU(),
  1248. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
  1249. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1250. (norm_layer(embed_dim) if patch_norm else nn.Identity()),
  1251. )
  1252. @staticmethod
  1253. def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
  1254. # if channel first, then Norm and Output are both channel_first
  1255. return nn.Sequential(
  1256. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1257. nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
  1258. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1259. norm_layer(out_dim),
  1260. )
  1261. @staticmethod
  1262. def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
  1263. # if channel first, then Norm and Output are both channel_first
  1264. return nn.Sequential(
  1265. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1266. nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
  1267. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1268. norm_layer(out_dim),
  1269. )
  1270. @staticmethod
  1271. def _make_layer(
  1272. dim=96,
  1273. drop_path=[0.1, 0.1],
  1274. use_checkpoint=False,
  1275. norm_layer=nn.LayerNorm,
  1276. downsample=nn.Identity(),
  1277. channel_first=False,
  1278. # ===========================
  1279. ssm_d_state=16,
  1280. ssm_ratio=2.0,
  1281. ssm_dt_rank="auto",
  1282. ssm_act_layer=nn.SiLU,
  1283. ssm_conv=3,
  1284. ssm_conv_bias=True,
  1285. ssm_drop_rate=0.0,
  1286. ssm_init="v0",
  1287. forward_type="v2",
  1288. # ===========================
  1289. mlp_ratio=4.0,
  1290. mlp_act_layer=nn.GELU,
  1291. mlp_drop_rate=0.0,
  1292. gmlp=False,
  1293. # ===========================
  1294. _SS2D=SS2D,
  1295. **kwargs,
  1296. ):
  1297. # if channel first, then Norm and Output are both channel_first
  1298. depth = len(drop_path)
  1299. blocks = []
  1300. for d in range(depth):
  1301. blocks.append(VSSBlock(
  1302. hidden_dim=dim,
  1303. drop_path=drop_path[d],
  1304. norm_layer=norm_layer,
  1305. channel_first=channel_first,
  1306. ssm_d_state=ssm_d_state,
  1307. ssm_ratio=ssm_ratio,
  1308. ssm_dt_rank=ssm_dt_rank,
  1309. ssm_act_layer=ssm_act_layer,
  1310. ssm_conv=ssm_conv,
  1311. ssm_conv_bias=ssm_conv_bias,
  1312. ssm_drop_rate=ssm_drop_rate,
  1313. ssm_init=ssm_init,
  1314. forward_type=forward_type,
  1315. mlp_ratio=mlp_ratio,
  1316. mlp_act_layer=mlp_act_layer,
  1317. mlp_drop_rate=mlp_drop_rate,
  1318. gmlp=gmlp,
  1319. use_checkpoint=use_checkpoint,
  1320. _SS2D=_SS2D,
  1321. ))
  1322. return nn.Sequential(OrderedDict(
  1323. blocks=nn.Sequential(*blocks,),
  1324. downsample=downsample,
  1325. ))
  1326. def forward(self, x: torch.Tensor):
  1327. x = self.patch_embed(x)
  1328. if self.pos_embed is not None:
  1329. pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed
  1330. x = x + pos_embed
  1331. for layer in self.layers:
  1332. x = layer(x)
  1333. x = self.classifier(x)
  1334. return x
  1335. def flops(self, shape=(3, 224, 224), verbose=True):
  1336. from fvcore.nn import flop_count, parameter_count
  1337. # shape = self.__input_shape__[1:]
  1338. supported_ops={
  1339. "aten::silu": None, # as relu is in _IGNORED_OPS
  1340. "aten::neg": None, # as relu is in _IGNORED_OPS
  1341. "aten::exp": None, # as relu is in _IGNORED_OPS
  1342. "aten::flip": None, # as permute is in _IGNORED_OPS
  1343. # "prim::PythonOp.CrossScan": None,
  1344. # "prim::PythonOp.CrossMerge": None,
  1345. "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose),
  1346. }
  1347. model = copy.deepcopy(self)
  1348. model.cuda().eval()
  1349. input = torch.randn((1, *shape), device=next(model.parameters()).device)
  1350. params = parameter_count(model)[""]
  1351. Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
  1352. del model, input
  1353. return sum(Gflops.values()) * 1e9
  1354. return f"params {params} GFLOPs {sum(Gflops.values())}"
  1355. # used to load ckpt from previous training code
  1356. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  1357. def check_name(src, state_dict: dict = state_dict, strict=False):
  1358. if strict:
  1359. if prefix + src in list(state_dict.keys()):
  1360. return True
  1361. else:
  1362. key = prefix + src
  1363. for k in list(state_dict.keys()):
  1364. if k.startswith(key):
  1365. return True
  1366. return False
  1367. def change_name(src, dst, state_dict: dict = state_dict, strict=False):
  1368. if strict:
  1369. if prefix + src in list(state_dict.keys()):
  1370. state_dict[prefix + dst] = state_dict[prefix + src]
  1371. state_dict.pop(prefix + src)
  1372. else:
  1373. key = prefix + src
  1374. for k in list(state_dict.keys()):
  1375. if k.startswith(key):
  1376. new_k = prefix + dst + k[len(key):]
  1377. state_dict[new_k] = state_dict[k]
  1378. state_dict.pop(k)
  1379. if check_name("pos_embed", strict=True):
  1380. srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"]
  1381. state_dict[prefix + "pos_embed"] = F.interpolate(srcEmb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic").to(srcEmb.device)
  1382. change_name("patch_embed.proj", "patch_embed.0")
  1383. change_name("patch_embed.norm", "patch_embed.2")
  1384. for i in range(100):
  1385. for j in range(100):
  1386. change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm")
  1387. change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op")
  1388. change_name("norm", "classifier.norm")
  1389. change_name("head", "classifier.head")
  1390. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  1391. # compatible with openmmlab
  1392. class Backbone_VSSM(VSSM):
  1393. def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
  1394. kwargs.update(norm_layer=norm_layer)
  1395. super().__init__(**kwargs)
  1396. self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
  1397. _NORMLAYERS = dict(
  1398. ln=nn.LayerNorm,
  1399. ln2d=LayerNorm2d,
  1400. bn=nn.BatchNorm2d,
  1401. )
  1402. norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)
  1403. self.out_indices = out_indices
  1404. for i in out_indices:
  1405. layer = norm_layer(self.dims[i])
  1406. layer_name = f'outnorm{i}'
  1407. self.add_module(layer_name, layer)
  1408. del self.classifier
  1409. self.load_pretrained(pretrained)
  1410. def load_pretrained(self, ckpt=None, key="model"):
  1411. if ckpt is None:
  1412. return
  1413. try:
  1414. _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu"))
  1415. print(f"Successfully load ckpt {ckpt}")
  1416. incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
  1417. print(incompatibleKeys)
  1418. except Exception as e:
  1419. print(f"Failed loading checkpoint form {ckpt}: {e}")
  1420. def forward(self, x):
  1421. def layer_forward(l, x):
  1422. x = l.blocks(x)
  1423. y = l.downsample(x)
  1424. return x, y
  1425. x = self.patch_embed(x)
  1426. outs = []
  1427. for i, layer in enumerate(self.layers):
  1428. o, x = layer_forward(layer, x) # (B, H, W, C)
  1429. if i in self.out_indices:
  1430. norm_layer = getattr(self, f'outnorm{i}')
  1431. out = norm_layer(o)
  1432. if not self.channel_first:
  1433. out = out.permute(0, 3, 1, 2)
  1434. outs.append(out.contiguous())
  1435. if len(self.out_indices) == 0:
  1436. return x
  1437. return outs
  1438. # =====================================================
  1439. def vanilla_vmamba_tiny():
  1440. return VSSM(
  1441. depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2,
  1442. patch_size=4, in_chans=3, num_classes=1000,
  1443. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1444. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1445. ssm_init="v0", forward_type="v0",
  1446. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1447. patch_norm=True, norm_layer="ln",
  1448. downsample_version="v1", patchembed_version="v1",
  1449. use_checkpoint=False, posembed=False, imgsize=224,
  1450. )
  1451. def vanilla_vmamba_small():
  1452. return VSSM(
  1453. depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3,
  1454. patch_size=4, in_chans=3, num_classes=1000,
  1455. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1456. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1457. ssm_init="v0", forward_type="v0",
  1458. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1459. patch_norm=True, norm_layer="ln",
  1460. downsample_version="v1", patchembed_version="v1",
  1461. use_checkpoint=False, posembed=False, imgsize=224,
  1462. )
  1463. def vanilla_vmamba_base():
  1464. return VSSM(
  1465. depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6,
  1466. patch_size=4, in_chans=3, num_classes=1000,
  1467. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1468. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1469. ssm_init="v0", forward_type="v0",
  1470. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1471. patch_norm=True, norm_layer="ln",
  1472. downsample_version="v1", patchembed_version="v1",
  1473. use_checkpoint=False, posembed=False, imgsize=224,
  1474. )
  1475. # =====================================================
  1476. def vmamba_tiny_s2l5(channel_first=True):
  1477. return VSSM(
  1478. depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2,
  1479. patch_size=4, in_chans=3, num_classes=1000,
  1480. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1481. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1482. ssm_init="v0", forward_type="v05_noz",
  1483. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1484. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1485. downsample_version="v3", patchembed_version="v2",
  1486. use_checkpoint=False, posembed=False, imgsize=224,
  1487. )
  1488. def vmamba_small_s2l15(channel_first=True):
  1489. return VSSM(
  1490. depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3,
  1491. patch_size=4, in_chans=3, num_classes=1000,
  1492. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1493. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1494. ssm_init="v0", forward_type="v05_noz",
  1495. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1496. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1497. downsample_version="v3", patchembed_version="v2",
  1498. use_checkpoint=False, posembed=False, imgsize=224,
  1499. )
  1500. def vmamba_base_s2l15(channel_first=True):
  1501. return VSSM(
  1502. depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6,
  1503. patch_size=4, in_chans=3, num_classes=1000,
  1504. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1505. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1506. ssm_init="v0", forward_type="v05_noz",
  1507. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1508. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1509. downsample_version="v3", patchembed_version="v2",
  1510. use_checkpoint=False, posembed=False, imgsize=224,
  1511. )
  1512. # =====================================================
  1513. def vmamba_tiny_s1l8(channel_first=True):
  1514. return VSSM(
  1515. depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2,
  1516. patch_size=4, in_chans=3, num_classes=1000,
  1517. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1518. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1519. ssm_init="v0", forward_type="v05_noz",
  1520. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1521. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1522. downsample_version="v3", patchembed_version="v2",
  1523. use_checkpoint=False, posembed=False, imgsize=224,
  1524. )
  1525. def vmamba_small_s1l20(channel_first=True):
  1526. return VSSM(
  1527. depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3,
  1528. patch_size=4, in_chans=3, num_classes=1000,
  1529. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1530. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1531. ssm_init="v0", forward_type="v05_noz",
  1532. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1533. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1534. downsample_version="v3", patchembed_version="v2",
  1535. use_checkpoint=False, posembed=False, imgsize=224,
  1536. )
  1537. def vmamba_base_s1l20(channel_first=True):
  1538. return VSSM(
  1539. depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5,
  1540. patch_size=4, in_chans=3, num_classes=1000,
  1541. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1542. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1543. ssm_init="v0", forward_type="v05_noz",
  1544. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1545. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1546. downsample_version="v3", patchembed_version="v2",
  1547. use_checkpoint=False, posembed=False, imgsize=224,
  1548. )
  1549. # mamba2 support =====================================================
# FLOPs counting does not currently work for mamba2 models!
  1551. def vmamba_tiny_m2():
  1552. return VSSM(
  1553. depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2,
  1554. patch_size=4, in_chans=3, num_classes=1000,
  1555. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1556. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1557. ssm_init="v2", forward_type="m0_noz",
  1558. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1559. patch_norm=True, norm_layer="ln",
  1560. downsample_version="v3", patchembed_version="v2",
  1561. use_checkpoint=False, posembed=False, imgsize=224,
  1562. )
  1563. def vmamba_small_m2():
  1564. return VSSM(
  1565. depths=[2, 2, 12, 2], dims=96, drop_path_rate=0.3,
  1566. patch_size=4, in_chans=3, num_classes=1000,
  1567. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1568. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1569. ssm_init="v2", forward_type="m0_noz",
  1570. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1571. patch_norm=True, norm_layer="ln",
  1572. downsample_version="v3", patchembed_version="v2",
  1573. use_checkpoint=False, posembed=False, imgsize=224,
  1574. )
  1575. def vmamba_base_m2():
  1576. return VSSM(
  1577. depths=[2, 2, 12, 2], dims=128, drop_path_rate=0.3,
  1578. patch_size=4, in_chans=3, num_classes=1000,
  1579. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1580. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1581. ssm_init="v2", forward_type="m0_noz",
  1582. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1583. patch_norm=True, norm_layer="ln",
  1584. downsample_version="v3", patchembed_version="v2",
  1585. use_checkpoint=False, posembed=False, imgsize=224,
  1586. )
  1587. if __name__ == "__main__":
  1588. from fvcore.nn import parameter_count
  1589. model_ref = vmamba_tiny_s1l8()
  1590. model = VSSM(
  1591. depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2,
  1592. patch_size=4, in_chans=3, num_classes=1000,
  1593. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1594. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1595. ssm_init="v2", forward_type="m0_noz",
  1596. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1597. patch_norm=True, norm_layer="ln",
  1598. downsample_version="v3", patchembed_version="v2",
  1599. use_checkpoint=False, posembed=False, imgsize=224,
  1600. )
  1601. print(parameter_count(model)[""])
  1602. print(model.flops()) # wrong
  1603. model.cuda().train()
  1604. model_ref.cuda().train()
  1605. def bench(model):
  1606. import time
  1607. inp = torch.randn((128, 3, 224, 224)).cuda()
  1608. for _ in range(30):
  1609. model(inp)
  1610. torch.cuda.synchronize()
  1611. tim = time.time()
  1612. for _ in range(30):
  1613. model(inp)
  1614. torch.cuda.synchronize()
  1615. tim1 = time.time() - tim
  1616. for _ in range(30):
  1617. model(inp).sum().backward()
  1618. torch.cuda.synchronize()
  1619. tim = time.time()
  1620. for _ in range(30):
  1621. model(inp).sum().backward()
  1622. torch.cuda.synchronize()
  1623. tim2 = time.time() - tim
  1624. return tim1 / 30, tim2 / 30
  1625. print(bench(model_ref))
  1626. print(bench(model))
  1627. breakpoint()