vmamba.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850
  1. import os
  2. import time
  3. import math
  4. import copy
  5. from functools import partial
  6. from typing import Optional, Callable, Any
  7. from collections import OrderedDict
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. import torch.utils.checkpoint as checkpoint
  12. from timm.models.layers import DropPath, trunc_normal_
  13. from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
  # Shorter repr for timm's DropPath when printing models.
  DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

  # NOTE: training is slower with these cuDNN options enabled, so they stay off.
  # torch.backends.cudnn.enabled = True
  # torch.backends.cudnn.benchmark = True
  # torch.backends.cudnn.deterministic = True

  # Prefer package-relative imports and fall back to top-level imports when this
  # file is used outside its package. Only ImportError triggers the fallback, so
  # genuine errors raised while importing these modules are not masked.
  try:
      from .csm_triton import cross_scan_fn, cross_merge_fn
  except ImportError:
      from csm_triton import cross_scan_fn, cross_merge_fn
  try:
      from .csms6s import selective_scan_fn, selective_scan_flop_jit
  except ImportError:
      from csms6s import selective_scan_fn, selective_scan_flop_jit
  # The FLOPs counter is not prepared for mamba2.
  try:
      from .mamba2.ssd_minimal import selective_scan_chunk_fn
  except ImportError:
      from mamba2.ssd_minimal import selective_scan_chunk_fn
  32. # =====================================================
  class Linear2d(nn.Linear):
      """``nn.Linear`` applied over dim 1 of a channel-first (B, C, H, W) tensor.

      Linear and conv layers are initialized differently, so this subclasses
      ``nn.Linear`` (keeping its init and its 2-D ``(out, in)`` weight) while
      computing the output with a 1x1 ``F.conv2d``. State dicts holding either a
      linear-style ``(out, in)`` weight or a conv-style ``(out, in, 1, 1)``
      weight can be loaded.
      """

      def forward(self, x: torch.Tensor):
          # (B, C_in, H, W) -> (B, C_out, H, W): the weight is used as a 1x1 kernel.
          return F.conv2d(x, self.weight[:, :, None, None], self.bias)

      def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
          key = prefix + "weight"
          weight = state_dict.get(key, None)
          # Reshape a conv-style (out, in, 1, 1) weight into this module's layout.
          # A missing key or a wrong element count is left to the parent, which
          # records it as a missing key / size mismatch instead of crashing here.
          if weight is not None and weight.numel() == self.weight.numel():
              state_dict[key] = weight.view(self.weight.shape)
          return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  class LayerNorm2d(nn.LayerNorm):
      """LayerNorm over the channel dimension of a channel-first (B, C, H, W) tensor."""

      def forward(self, x: torch.Tensor):
          # Move channels last, normalize them, then restore channel-first layout.
          channels_last = x.permute(0, 2, 3, 1)
          normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
          return normed.permute(0, 3, 1, 2)
  class PatchMerging2D(nn.Module):
      """Downsample 2x by gathering each 2x2 spatial patch into channels.

      The 4 sub-grids are concatenated on the channel axis (4*C), normalized
      with ``norm_layer(4 * dim)`` and projected without bias.
      Odd H or W is zero-padded on the bottom/right edge first.

      Args:
          dim: number of input channels C.
          out_dim: output channels; a negative value means ``2 * dim``.
          norm_layer: normalization constructed with ``4 * dim`` features.
          channel_first: if True the input is (B, C, H, W), else (B, H, W, C).
      """

      def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False):
          super().__init__()
          self.dim = dim
          Linear = Linear2d if channel_first else nn.Linear
          self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last
          self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
          self.norm = norm_layer(4 * dim)

      @staticmethod
      def _patch_merging_pad_channel_last(x: torch.Tensor):
          H, W, _ = x.shape[-3:]
          if (W % 2 != 0) or (H % 2 != 0):
              # F.pad pairs run from the last dim backwards: (C, W, H).
              x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
          x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
          x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
          x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
          x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
          x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C
          return x

      @staticmethod
      def _patch_merging_pad_channel_first(x: torch.Tensor):
          H, W = x.shape[-2:]
          if (W % 2 != 0) or (H % 2 != 0):
              # Here the last two dims are (H, W), so the pairs are (W, H); the
              # channel dim must not be padded.
              x = F.pad(x, (0, W % 2, 0, H % 2))
          x0 = x[..., 0::2, 0::2]  # ... H/2 W/2
          x1 = x[..., 1::2, 0::2]  # ... H/2 W/2
          x2 = x[..., 0::2, 1::2]  # ... H/2 W/2
          x3 = x[..., 1::2, 1::2]  # ... H/2 W/2
          x = torch.cat([x0, x1, x2, x3], 1)  # ... 4*C H/2 W/2
          return x

      def forward(self, x):
          x = self._patch_merging_pad(x)
          x = self.norm(x)
          x = self.reduction(x)
          return x
  class Permute(nn.Module):
      """Module wrapper around ``Tensor.permute`` with a dim order fixed at construction."""

      def __init__(self, *args):
          super().__init__()
          # The dim order is kept under the attribute name ``args``.
          self.args = args

      def forward(self, x: torch.Tensor):
          order = self.args
          return x.permute(*order)
  class Mlp(nn.Module):
      """Two-layer feed-forward block: fc1 -> act -> dropout -> fc2 -> dropout.

      With ``channels_first=True`` the projections are ``Linear2d`` (acting on
      dim 1); otherwise they are ``nn.Linear`` (acting on the last dim).
      Missing or zero ``hidden_features``/``out_features`` fall back to ``in_features``.
      """

      def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
          super().__init__()
          if not out_features:
              out_features = in_features
          if not hidden_features:
              hidden_features = in_features
          if channels_first:
              proj_cls = Linear2d
          else:
              proj_cls = nn.Linear
          self.fc1 = proj_cls(in_features, hidden_features)
          self.act = act_layer()
          self.fc2 = proj_cls(hidden_features, out_features)
          self.drop = nn.Dropout(drop)

      def forward(self, x):
          # The same dropout module is applied after the activation and after fc2.
          for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
              x = layer(x)
          return x
  class gMlp(nn.Module):
      """Gated MLP: fc1 -> split into (value, gate) -> fc2(value * act(gate)) -> dropout.

      ``fc1`` produces ``2 * hidden_features`` channels. The split is on dim 1 when
      ``channels_first`` else on the last dim.
      """

      def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):
          super().__init__()
          self.channel_first = channels_first
          out_features = in_features if not out_features else out_features
          hidden_features = in_features if not hidden_features else hidden_features
          proj_cls = Linear2d if channels_first else nn.Linear
          self.fc1 = proj_cls(in_features, 2 * hidden_features)
          self.act = act_layer()
          self.fc2 = proj_cls(hidden_features, out_features)
          self.drop = nn.Dropout(drop)

      def forward(self, x: torch.Tensor):
          split_dim = 1 if self.channel_first else -1
          value, gate = self.fc1(x).chunk(2, dim=split_dim)
          return self.drop(self.fc2(value * self.act(gate)))
  class SoftmaxSpatial(nn.Softmax):
      """Softmax over all H*W spatial positions of a 4-D tensor.

      ``self.dim`` selects the input layout rather than the reduction axis:
      ``dim=-1`` expects (B, C, H, W); ``dim=1`` expects (B, H, W, C).
      The output has the same shape as the input.

      Raises:
          NotImplementedError: for any other ``dim``.
      """

      def forward(self, x: torch.Tensor):
          if self.dim == -1:
              B, C, H, W = x.shape
              # reshape (not view) so non-contiguous inputs, e.g. permuted maps, work.
              return super().forward(x.reshape(B, C, -1)).view(B, C, H, W)
          elif self.dim == 1:
              B, H, W, C = x.shape
              return super().forward(x.reshape(B, -1, C)).view(B, H, W, C)
          else:
              raise NotImplementedError
  133. # =====================================================
  class mamba_init:
      """Initializers for the selective-scan parameters: dt projection, A_log and D."""

      @staticmethod
      def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
          """Return an ``nn.Linear(dt_rank, d_inner)`` used as the dt projection.

          The weight is constant (``dt_init="constant"``) or uniform in
          ``[-b, b]`` (``"random"``), with ``b = dt_rank**-0.5 * dt_scale``.
          The bias is chosen so that ``softplus(bias)`` is log-uniform in
          [dt_min, dt_max], floored at ``dt_init_floor``.

          Raises:
              NotImplementedError: for any other ``dt_init``.
          """
          proj = nn.Linear(dt_rank, d_inner, bias=True)

          # Scale the weight to preserve variance at initialization.
          bound = dt_scale * dt_rank ** -0.5
          if dt_init == "random":
              nn.init.uniform_(proj.weight, -bound, bound)
          elif dt_init == "constant":
              nn.init.constant_(proj.weight, bound)
          else:
              raise NotImplementedError

          # Sample dt log-uniformly in [dt_min, dt_max].
          log_lo = math.log(dt_min)
          log_hi = math.log(dt_max)
          dt = torch.exp(torch.rand(d_inner) * (log_hi - log_lo) + log_lo).clamp(min=dt_init_floor)
          # Inverse softplus, see https://github.com/pytorch/pytorch/issues/72759
          inv_dt = dt + torch.log(-torch.expm1(-dt))
          with torch.no_grad():
              proj.bias.copy_(inv_dt)
          # A blanket "zero every Linear.bias" init would clobber this bias; a
          # `_no_reinit` marker was considered for that but is left unset.
          return proj

      @staticmethod
      def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
          """S4D-real init: every row of A is [1, ..., d_state]; returns log(A) as a Parameter.

          Shape is (d_inner, d_state). With ``copies > 0`` it is repeated to
          (copies, d_inner, d_state), flattened to (copies * d_inner, d_state)
          when ``merge`` is set. The Parameter is tagged ``_no_weight_decay``.
          """
          ramp = torch.arange(1, d_state + 1, dtype=torch.float32, device=device)
          log_a = torch.log(ramp.view(1, -1).repeat(d_inner, 1).contiguous())  # kept in fp32
          if copies > 0:
              log_a = log_a[None].repeat(copies, 1, 1).contiguous()
              if merge:
                  log_a = log_a.flatten(0, 1)
          param = nn.Parameter(log_a)
          param._no_weight_decay = True
          return param

      @staticmethod
      def D_init(d_inner, copies=-1, device=None, merge=True):
          """Skip ("D") parameter initialized to ones, shape (d_inner,).

          With ``copies > 0`` it becomes (copies, d_inner), flattened to
          (copies * d_inner,) when ``merge`` is set. Tagged ``_no_weight_decay``.
          """
          skip = torch.ones(d_inner, device=device)
          if copies > 0:
              skip = skip[None].repeat(copies, 1).contiguous()
              if merge:
                  skip = skip.flatten(0, 1)
          param = nn.Parameter(skip)  # kept in fp32
          param._no_weight_decay = True
          return param

      @classmethod
      def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
          """Create all scan parameters for ``k_group`` scan directions.

          Returns:
              (A_logs (K*d_inner, d_state), Ds (K*d_inner,),
               dt_projs_weight (K, d_inner, dt_rank), dt_projs_bias (K, d_inner)).
          """
          projs = [
              cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
              for _ in range(k_group)
          ]
          dt_projs_weight = nn.Parameter(torch.stack([p.weight for p in projs], dim=0))  # (K, inner, rank)
          dt_projs_bias = nn.Parameter(torch.stack([p.bias for p in projs], dim=0))  # (K, inner)
          del projs

          A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True)  # (K * D, N)
          Ds = cls.D_init(d_inner, copies=k_group, merge=True)  # (K * D)
          return A_logs, Ds, dt_projs_weight, dt_projs_bias
  # support: v0, v0seq
  class SS2Dv0:
      """Original ("v0") SS2D: 4-direction 2-D selective scan over a channel-last map.

      The class has no base of its own. Its ``__initv0__`` assigns ``nn.Module``
      layers/parameters to ``self``, so it is presumably mixed into an
      ``nn.Module`` subclass defined elsewhere (TODO confirm against the caller).
      """

      def __initv0__(
          self,
          # basic dims ===========
          d_model=96,
          d_state=16,
          ssm_ratio=2.0,
          dt_rank="auto",
          # ======================
          dropout=0.0,
          # ======================
          seq=False,
          force_fp32=True,
          **kwargs,
      ):
          """Build the v0 layers.

          Args:
              d_model: input/output channels.
              d_state: SSM state size N.
              ssm_ratio: inner width is ``int(ssm_ratio * d_model)``.
              dt_rank: rank of the dt projection; "auto" means ``ceil(d_model / 16)``.
              dropout: dropout after ``out_proj`` (Identity when 0).
              seq: run the four scan directions one at a time in ``forwardv0``.
              force_fp32: cast scan inputs to float32 in ``forwardv0``.
              kwargs: ignored, except that a truthy ``channel_first`` is rejected
                  with an assertion (v0 is channel-last only).
          """
          if "channel_first" in kwargs:
              assert not kwargs["channel_first"]
          # Fixed v0 hyper-parameters.
          act_layer = nn.SiLU
          dt_min = 0.001
          dt_max = 0.1
          dt_init = "random"
          dt_scale = 1.0
          dt_init_floor = 1e-4
          bias = False
          conv_bias = True
          d_conv = 3
          k_group = 4
          factory_kwargs = {"device": None, "dtype": None}
          super().__init__()
          d_inner = int(ssm_ratio * d_model)
          dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank

          # Bind both flags in one partial: two separate assignments would let the
          # force_fp32 one silently drop seq=True.
          if seq or not force_fp32:
              self.forward = partial(self.forwardv0, seq=seq, force_fp32=force_fp32)
          else:
              self.forward = self.forwardv0

          # in proj ============================
          self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
          self.act: nn.Module = act_layer()
          self.conv2d = nn.Conv2d(
              in_channels=d_inner,
              out_channels=d_inner,
              groups=d_inner,
              bias=conv_bias,
              kernel_size=d_conv,
              padding=(d_conv - 1) // 2,
              **factory_kwargs,
          )

          # x proj ============================
          # Build one Linear per scan direction only to borrow their init, then
          # keep just the stacked weights.
          x_projs = [
              nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
              for _ in range(k_group)
          ]
          self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_projs], dim=0))  # (K, N, inner)
          del x_projs

          # dt proj, A, D ============================
          self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
              d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=k_group,
          )

          # out proj =======================================
          self.out_norm = nn.LayerNorm(d_inner)
          self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
          self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

      def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
          """Apply the block to a channel-last map.

          Args:
              x: (B, H, W, d_model), per the chunk on dim -1 and the permute below.
              seq: scan the four directions one at a time instead of in one call.
              force_fp32: cast xs, dts, Bs, Cs to float32 before the scan.

          Returns:
              (B, H, W, d_model).
          """
          x = self.in_proj(x)
          x, z = x.chunk(2, dim=-1)  # (b, h, w, d)
          z = self.act(z)
          x = x.permute(0, 3, 1, 2).contiguous()
          x = self.conv2d(x)  # (b, d, h, w)
          x = self.act(x)
          selective_scan = partial(selective_scan_fn, backend="mamba")

          B, _, H, W = x.shape
          N = self.A_logs.shape[1]
          K, D, R = self.dt_projs_weight.shape
          L = H * W

          # Four scan orders: row-major (hw), column-major (wh), and both reversed.
          x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
          xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

          x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
          if hasattr(self, "x_proj_bias"):
              x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
          dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
          dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

          xs = xs.view(B, -1, L)  # (b, k * d, l)
          dts = dts.contiguous().view(B, -1, L)  # (b, k * d, l)
          Bs = Bs.contiguous()  # (b, k, d_state, l)
          Cs = Cs.contiguous()  # (b, k, d_state, l)

          As = -self.A_logs.float().exp()  # (k * d, d_state)
          Ds = self.Ds.float()  # (k * d)
          dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

          if force_fp32:
              xs, dts, Bs, Cs = (t.to(torch.float32) for t in (xs, dts, Bs, Cs))

          if seq:
              out_y = []
              for i in range(K):
                  yi = selective_scan(
                      xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
                      As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
                      delta_bias=dt_projs_bias.view(K, -1)[i],
                      delta_softplus=True,
                  ).view(B, -1, L)
                  out_y.append(yi)
              out_y = torch.stack(out_y, dim=1)
          else:
              out_y = selective_scan(
                  xs, dts,
                  As, Bs, Cs, Ds,
                  delta_bias=dt_projs_bias,
                  delta_softplus=True,
              ).view(B, K, -1, L)
          assert out_y.dtype == torch.float

          # Undo the reversals and the transpose, then sum the four directions.
          inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
          wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
          invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
          y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
          y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
          y = self.out_norm(y).view(B, H, W, -1)
          y = y * z
          out = self.dropout(self.out_proj(y))
          return out
  318. # support: v01-v05; v051d,v052d,v052dc;
  319. # postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32;
  320. # history support: v2,v3;v31d,v32d,v32dc;
  321. class SS2Dv2:
  def __initv2__(
      self,
      # basic dims ===========
      d_model=96,
      d_state=16,
      ssm_ratio=2.0,
      dt_rank="auto",
      act_layer=nn.SiLU,
      # dwconv ===============
      d_conv=3, # < 2 means no conv
      conv_bias=True,
      # ======================
      dropout=0.0,
      bias=False,
      # dt init ==============
      dt_min=0.001,
      dt_max=0.1,
      dt_init="random",
      dt_scale=1.0,
      dt_init_floor=1e-4,
      initialize="v0",
      # ======================
      forward_type="v2",
      channel_first=False,
      # ======================
      **kwargs,
  ):
      """Build the v2-family SS2D layers; ``forward_type`` selects the scan core.

      ``forward_type`` may end in postfix tags, stripped in this order:
      - ``_no32``: do not force fp32 in the core.
      - ``_oact``: GELU on the core output.
      - ``_noz``: no gate branch.
      - ``_nozact``: no activation on the gate.
      - then the out-norm tags handled by ``get_outnorm``.

      The remaining name must be a key of ``FORWARD_TYPES`` below.

      ``initialize`` selects the dt/A/D init:
      - "v0": ``mamba_init``.
      - "v1": scaled Gaussian.
      - "v2": zeros / scaled uniform.

      Raises:
          ValueError: if the stripped ``forward_type`` or ``initialize`` is not
              supported (otherwise the module would only fail on first forward).
      """
      factory_kwargs = {"device": None, "dtype": None}
      super().__init__()
      self.k_group = 4
      self.d_model = int(d_model)
      self.d_state = int(d_state)
      self.d_inner = int(ssm_ratio * d_model)
      self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
      self.channel_first = channel_first
      self.with_dconv = d_conv > 1
      Linear = Linear2d if channel_first else nn.Linear
      self.forward = self.forwardv2

      # tags for forward_type ==============================
      checkpostfix = self.checkpostfix
      self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
      self.oact, forward_type = checkpostfix("_oact", forward_type)
      self.disable_z, forward_type = checkpostfix("_noz", forward_type)
      self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
      self.out_norm, forward_type = self.get_outnorm(forward_type, self.d_inner, channel_first)

      # forward_type debug =======================================
      FORWARD_TYPES = dict(
          v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", scan_force_torch=True),
          v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"),
          v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"),
          v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d"
          v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), # selective_scan_backend="oflex", scan_mode="cross2d"
          # ===============================
          v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"),
          v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"),
          v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"),
          v052d3=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode=3), # debug
          # ===============================
          v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"),
          v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"),
      )
      if forward_type not in FORWARD_TYPES:
          raise ValueError(
              f"Unsupported forward_type {forward_type!r} (after removing postfix tags); "
              f"expected one of {sorted(FORWARD_TYPES)}."
          )
      self.forward_core = FORWARD_TYPES[forward_type]

      # in proj =======================================
      d_proj = self.d_inner if self.disable_z else (self.d_inner * 2)
      self.in_proj = Linear(self.d_model, d_proj, bias=bias)
      self.act: nn.Module = act_layer()

      # conv =======================================
      if self.with_dconv:
          self.conv2d = nn.Conv2d(
              in_channels=self.d_inner,
              out_channels=self.d_inner,
              groups=self.d_inner,
              bias=conv_bias,
              kernel_size=d_conv,
              padding=(d_conv - 1) // 2,
              **factory_kwargs,
          )

      # x proj ============================
      # One Linear per scan direction only to borrow its init; keep the stacked weights.
      x_projs = [
          nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False)
          for _ in range(self.k_group)
      ]
      self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in x_projs], dim=0))  # (K, N, inner)
      del x_projs

      # out proj =======================================
      self.out_act = nn.GELU() if self.oact else nn.Identity()
      self.out_proj = Linear(self.d_inner, self.d_model, bias=bias)
      self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

      if initialize in ["v0"]:
          self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
              self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
          )
      elif initialize in ["v1"]:
          # simple init dt_projs, A_logs, Ds
          self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
          self.A_logs = nn.Parameter(torch.randn((self.k_group * self.d_inner, self.d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
          self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner, self.dt_rank)))  # 0.1 is added in 0430
          self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner)))  # 0.1 is added in 0430
      elif initialize in ["v2"]:
          # simple init dt_projs, A_logs, Ds
          self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
          self.A_logs = nn.Parameter(torch.zeros((self.k_group * self.d_inner, self.d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
          self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner, self.dt_rank)))
          self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner)))
      else:
          raise ValueError(f"Unsupported initialize {initialize!r}; expected 'v0', 'v1' or 'v2'.")
  def forward_corev2(
      self,
      x: torch.Tensor=None,
      # ==============================
      force_fp32=False, # True: input fp32
      # ==============================
      ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input
      no_einsum=False, # replace einsum with linear or conv1d to raise throughput
      # ==============================
      selective_scan_backend = None,
      # ==============================
      scan_mode = "cross2d",
      scan_force_torch = False,
      # ==============================
      **kwargs,
  ):
      """Run the 2-D selective-scan core on a channel-first map.

      Args:
          x: (B, D, H, W) input (unpacked as such below).
          force_fp32: cast the scan inputs (xs, dts, Bs, Cs) to float32.
          ssoflex: forwarded to ``selective_scan_fn``.
          no_einsum: compute the x/dt projections with grouped ``F.conv1d``
              instead of ``torch.einsum``.
          selective_scan_backend: one of None, "oflex", "core", "mamba", "torch";
              forwarded to ``selective_scan_fn``. The "v2" forward type uses "core".
          scan_mode: "cross2d" (0), "unidi" (1), "bidi" (2), "cascade2d" (-1),
              or a raw int (debug).
          scan_force_torch: forwarded to ``cross_scan_fn`` / ``cross_merge_fn``.

      Returns:
          ``self.out_norm`` applied to the merged scan output laid out as
          (B, D, H, W) when ``self.channel_first`` else (B, H, W, D), cast to
          the input dtype.
      """
      assert selective_scan_backend in [None, "oflex", "core", "mamba", "torch"]
      _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode # for debug
      assert isinstance(_scan_mode, int)
      delta_softplus = True
      out_norm = self.out_norm
      channel_first = self.channel_first
      to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

      B, D, H, W = x.shape
      N = self.d_state
      K, D, R = self.k_group, self.d_inner, self.dt_rank
      L = H * W

      def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
          return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend)

      if _scan_mode == -1:
          # cascade2d: a bidirectional scan along every row (groups 0-1), then a
          # bidirectional scan along every column of that result (groups 2-3).
          x_proj_bias = getattr(self, "x_proj_bias", None)

          def scan_rowcol(
              x: torch.Tensor,
              proj_weight: torch.Tensor,
              proj_bias: torch.Tensor,
              dt_weight: torch.Tensor,
              dt_bias: torch.Tensor, # (2*c)
              _As: torch.Tensor, # As = -torch.exp(A_logs.to(torch.float))[:2,] # (2*c, d_state)
              _Ds: torch.Tensor,
              width = True,
          ):
              # x: (B, D, H, W); proj_weight: (2, R+N+N, D); dt_weight: (2, D, R).
              # Returns (B*H, 2, D, W) when width else (B*W, 2, D, H).
              XB, XD, XH, XW = x.shape
              if width:
                  _B, _D, _L = XB * XH, XD, XW
                  xs = x.permute(0, 2, 1, 3).contiguous()
              else:
                  _B, _D, _L = XB * XW, XD, XH
                  xs = x.permute(0, 3, 1, 2).contiguous()
              # Forward and reversed sequences, flattened to the 4-D (_B, 2, _D, _L)
              # layout that the einsum path expects.
              xs = torch.stack([xs, xs.flip(dims=[-1])], dim=2).view(_B, 2, _D, _L)
              if no_einsum:
                  x_dbl = F.conv1d(xs.view(_B, -1, _L), proj_weight.view(-1, _D, 1), bias=(proj_bias.view(-1) if proj_bias is not None else None), groups=2)
                  dts, Bs, Cs = torch.split(x_dbl.view(_B, 2, -1, _L), [R, N, N], dim=2)
                  dts = F.conv1d(dts.contiguous().view(_B, -1, _L), dt_weight.view(2 * _D, -1, 1), groups=2)
              else:
                  x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, proj_weight)
                  # Use this half's bias, not the full 4-group x_proj_bias.
                  if proj_bias is not None:
                      x_dbl = x_dbl + proj_bias.view(1, 2, -1, 1)
                  dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
                  dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_weight)

              xs = xs.view(_B, -1, _L)
              dts = dts.contiguous().view(_B, -1, _L)
              As = _As.view(-1, N).to(torch.float)
              Bs = Bs.contiguous().view(_B, 2, N, _L)
              Cs = Cs.contiguous().view(_B, 2, N, _L)
              Ds = _Ds.view(-1)
              delta_bias = dt_bias.view(-1).to(torch.float) if dt_bias is not None else None

              if force_fp32:
                  xs = xs.to(torch.float)
              dts = dts.to(xs.dtype)
              Bs = Bs.to(xs.dtype)
              Cs = Cs.to(xs.dtype)

              ys: torch.Tensor = selective_scan(
                  xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
              ).view(_B, 2, -1, _L)
              return ys

          As = -self.A_logs.to(torch.float).exp().view(4, -1, N)
          x = F.layer_norm(x.permute(0, 2, 3, 1), normalized_shape=(int(x.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
          y_row = scan_rowcol(
              x,
              proj_weight = self.x_proj_weight.view(4, -1, D)[:2].contiguous(),
              proj_bias = (x_proj_bias.view(4, -1)[:2].contiguous() if x_proj_bias is not None else None),
              dt_weight = self.dt_projs_weight.view(4, D, -1)[:2].contiguous(),
              dt_bias = (self.dt_projs_bias.view(4, -1)[:2].contiguous() if self.dt_projs_bias is not None else None),
              _As = As[:2].contiguous().view(-1, N),
              _Ds = self.Ds.view(4, -1)[:2].contiguous().view(-1),
              width=True,
          ).view(B, H, 2, -1, W).sum(dim=2).permute(0, 2, 1, 3) # (B,C,H,W)
          y_row = F.layer_norm(y_row.permute(0, 2, 3, 1), normalized_shape=(int(y_row.shape[1]),)).permute(0, 3, 1, 2).contiguous() # added0510 to avoid nan
          y_col = scan_rowcol(
              y_row,
              proj_weight = self.x_proj_weight.view(4, -1, D)[2:].contiguous().to(y_row.dtype),
              proj_bias = (x_proj_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if x_proj_bias is not None else None),
              dt_weight = self.dt_projs_weight.view(4, D, -1)[2:].contiguous().to(y_row.dtype),
              dt_bias = (self.dt_projs_bias.view(4, -1)[2:].contiguous().to(y_row.dtype) if self.dt_projs_bias is not None else None),
              _As = As[2:].contiguous().view(-1, N),
              _Ds = self.Ds.view(4, -1)[2:].contiguous().view(-1),
              width=False,
          ).view(B, W, 2, -1, H).sum(dim=2).permute(0, 2, 3, 1)
          y = y_col
      else:
          x_proj_bias = getattr(self, "x_proj_bias", None)
          xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
          if no_einsum:
              x_dbl = F.conv1d(xs.view(B, -1, L), self.x_proj_weight.view(-1, D, 1), bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
              dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
              if hasattr(self, "dt_projs_weight"):
                  dts = F.conv1d(dts.contiguous().view(B, -1, L), self.dt_projs_weight.view(K * D, -1, 1), groups=K)
          else:
              x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
              if x_proj_bias is not None:
                  x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
              dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
              if hasattr(self, "dt_projs_weight"):
                  dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)

          xs = xs.view(B, -1, L)
          dts = dts.contiguous().view(B, -1, L)
          As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
          Ds = self.Ds.to(torch.float) # (K * c)
          Bs = Bs.contiguous().view(B, K, N, L)
          Cs = Cs.contiguous().view(B, K, N, L)
          delta_bias = self.dt_projs_bias.view(-1).to(torch.float)

          if force_fp32:
              xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)

          ys: torch.Tensor = selective_scan(
              xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
          ).view(B, K, -1, H, W)

          y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)

          if getattr(self, "__DEBUG__", False):
              setattr(self, "__data__", dict(
                  A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
                  us=xs, dts=dts, delta_bias=delta_bias,
                  ys=ys, y=y, H=H, W=W,
              ))

      y = y.view(B, -1, H, W)
      if not channel_first:
          y = y.view(B, -1, H * W).transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1) # (B, L, C)
      y = out_norm(y)

      return y.to(x.dtype)
  def forwardv2(self, x: torch.Tensor, **kwargs):
      """Gated SS2D forward pass for the v2 family.

      The input is projected with ``in_proj``. Unless ``disable_z`` is set, the
      projection is split in two along the channel axis; the second half is the
      gate ``z``, optionally passed through ``act``. The main half is moved to a
      channel-first layout if needed, optionally depthwise-convolved, activated,
      scanned by ``forward_core``, passed through ``out_act``, multiplied by the
      gate, and finally sent through ``out_proj`` and ``dropout``.

      Args:
          x: input feature map. The channel axis is dim 1 when
              ``self.channel_first`` is true, otherwise the last dim.
          **kwargs: accepted for interface compatibility; unused.
      """
      projected = self.in_proj(x)
      gated = not self.disable_z
      if gated:
          channel_axis = 1 if self.channel_first else -1
          projected, gate = projected.chunk(2, dim=channel_axis)
          if not self.disable_z_act:
              gate = self.act(gate)
      # Convolution and scan below operate on a channel-first tensor.
      if not self.channel_first:
          projected = projected.permute(0, 3, 1, 2).contiguous()
      if self.with_dconv:
          projected = self.conv2d(projected)
      activated = self.act(projected)
      result = self.out_act(self.forward_core(activated))
      if gated:
          result = result * gate
      return self.dropout(self.out_proj(result))
  @staticmethod
  def get_outnorm(forward_type="", d_inner=192, channel_first=True):
      """Build the output normalisation selected by ``forward_type``'s suffix.

      The suffixes ``_onnone``, ``_ondwconv3``, ``_oncnorm``, ``_onsoftmax`` and
      ``_onsigmoid`` are tested and stripped from the end of ``forward_type``
      in that order. The module is chosen by priority
      none > cnorm > dwconv3 > softmax > sigmoid. With no flag, a LayerNorm over
      ``d_inner`` channels is used (``LayerNorm2d`` when ``channel_first``,
      otherwise ``nn.LayerNorm``).

      Returns:
          ``(out_norm, forward_type)``, where ``forward_type`` has the
          recognised suffixes removed.
      """
      def strip_tag(tag, text):
          hit = text[-len(tag):] == tag
          return hit, (text[:-len(tag)] if hit else text)

      norm_cls = LayerNorm2d if channel_first else nn.LayerNorm

      flags = {}
      for tag in ("_onnone", "_ondwconv3", "_oncnorm", "_onsoftmax", "_onsigmoid"):
          flags[tag], forward_type = strip_tag(tag, forward_type)

      if flags["_onnone"]:
          return nn.Identity(), forward_type
      if flags["_oncnorm"]:
          # LayerNorm followed by a 3x3 depthwise conv (conv needs channel-first).
          cnorm = nn.Sequential(
              norm_cls(d_inner),
              (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
              nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
              (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
          )
          return cnorm, forward_type
      if flags["_ondwconv3"]:
          # Bare 3x3 depthwise conv, permuted around when inputs are channel-last.
          dwconv = nn.Sequential(
              (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
              nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
              (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
          )
          return dwconv, forward_type
      if flags["_onsoftmax"]:
          return SoftmaxSpatial(dim=(-1 if channel_first else 1)), forward_type
      if flags["_onsigmoid"]:
          return nn.Sigmoid(), forward_type
      return norm_cls(d_inner), forward_type
  @staticmethod
  def checkpostfix(tag, value):
      """Check whether ``value`` ends with ``tag`` and strip it if so.

      Returns:
          ``(matched, remainder)``. ``matched`` is True when the last
          ``len(tag)`` characters of ``value`` equal ``tag``. ``remainder`` is
          ``value`` with that suffix removed, or ``value`` unchanged otherwise.
      """
      cut = len(tag)
      matched = value[-cut:] == tag
      remainder = value[:-cut] if matched else value
      return matched, remainder
  # Supported modes: xv1a, xv2a, xv3a.
  # Supported forward_type suffixes: _cpos; _ocov; _ocov2; _ca, _ca1; _act; _mul;
  # _onsigmoid, _onsoftmax, _ondwconv3, _onnone.
  class SS2Dv3:
      """Mixin implementing the experimental "xv" SS2D variants.

      A single ``in_proj`` produces, in one tensor:
      - the scan input ``us`` (``d_inner`` channels);
      - the step sizes ``dts`` (``dts_dim`` channels);
      - per-direction ``Bs`` and ``Cs`` (``4 * d_state`` channels each).

      The mode prefix of ``forward_type`` selects ``dts_dim``:

      * ``xv1a``: ``dt_rank`` channels, cross-scanned in 4 directions and then
        projected to ``d_inner`` per direction with ``dt_projs_weight``;
      * ``xv2a``: ``d_inner`` channels used directly (``dt_projs_weight`` is
        removed);
      * ``xv3a``: ``4 * dt_rank`` channels, one ``dt_rank`` slice per direction,
        projected with ``dt_projs_weight``.
      """

      def __initxv__(
          self,
          # basic dims ===========
          d_model=96,
          d_state=16,
          ssm_ratio=2.0,
          dt_rank="auto",
          # dwconv ===============
          d_conv=3, # < 2 means no conv
          conv_bias=True,
          # ======================
          dropout=0.0,
          bias=False,
          # dt init ==============
          dt_min=0.001,
          dt_max=0.1,
          dt_init="random",
          dt_scale=1.0,
          dt_init_floor=1e-4,
          initialize="v0",
          # ======================
          forward_type="v2",
          channel_first=False,
          # ======================
          **kwargs,
      ):
          """Initialise an "xv" SS2D block.

          Suffix tags are stripped from the end of ``forward_type`` (output norm,
          ``_mul``, ``_act``, then the conv tags when ``d_conv > 1``). The
          remaining 4-character prefix selects the mode.

          Raises:
              AssertionError: if the mode prefix is not xv1a, xv2a or xv3a.
              ValueError: if ``initialize`` is not "v0", "v1" or "v2".
          """
          super().__init__()
          if initialize not in ("v0", "v1", "v2"):
              # Any other value would leave A_logs / Ds / dt_projs_* undefined
              # and only fail later inside forwardxv.
              raise ValueError(f"unsupported initialize {initialize!r}; expected 'v0', 'v1' or 'v2'")
          d_inner = int(ssm_ratio * d_model)
          dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
          self.channel_first = channel_first
          self.d_state = d_state
          self.dt_rank = dt_rank
          self.d_inner = d_inner
          k_group = 4
          self.with_dconv = d_conv > 1
          Linear = Linear2d if channel_first else nn.Linear

          # tags for forward_type ==============================
          checkpostfix = SS2Dv2.checkpostfix
          self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, channel_first)
          self.omul, forward_type = checkpostfix("_mul", forward_type)
          self.oact, forward_type = checkpostfix("_act", forward_type)
          self.f_omul = nn.Identity() if self.omul else None
          self.out_act = nn.GELU() if self.oact else nn.Identity()

          mode = forward_type[:4]
          assert mode in ["xv1a", "xv2a", "xv3a"]
          self.forward = partial(self.forwardxv, mode=mode)
          self.dts_dim = dict(xv1a=self.dt_rank, xv2a=self.d_inner, xv3a=4 * self.dt_rank)[mode]
          d_inner_all = d_inner + self.dts_dim + 8 * d_state
          self.in_proj = Linear(d_model, d_inner_all, bias=bias)

          # conv =======================================
          self.cpos = False
          self.iconv = False
          self.oconv = False
          self.oconv2 = False
          if self.with_dconv:
              cact, forward_type = checkpostfix("_ca", forward_type)
              cact1, forward_type = checkpostfix("_ca1", forward_type)
              self.cact = nn.SiLU() if cact else nn.Identity()
              self.cact = nn.GELU() if cact1 else self.cact

              self.oconv2, forward_type = checkpostfix("_ocov2", forward_type)
              self.oconv, forward_type = checkpostfix("_ocov", forward_type)
              # "_cpos" only has an effect on the input-conv path (iconv).
              self.cpos, forward_type = checkpostfix("_cpos", forward_type)
              self.iconv = (not self.oconv) and (not self.oconv2)

              if self.iconv:
                  # depthwise conv on the block input (d_model channels)
                  self.conv2d = nn.Conv2d(
                      in_channels=d_model,
                      out_channels=d_model,
                      groups=d_model,
                      bias=conv_bias,
                      kernel_size=d_conv,
                      padding=(d_conv - 1) // 2,
                  )
              if self.oconv:
                  # depthwise conv on the scan input, added to the output
                  self.oconv2d = nn.Conv2d(
                      in_channels=d_inner,
                      out_channels=d_inner,
                      groups=d_inner,
                      bias=conv_bias,
                      kernel_size=d_conv,
                      padding=(d_conv - 1) // 2,
                  )
              if self.oconv2:
                  # depthwise conv over the whole in_proj output
                  self.conv2d = nn.Conv2d(
                      in_channels=d_inner_all,
                      out_channels=d_inner_all,
                      groups=d_inner_all,
                      bias=conv_bias,
                      kernel_size=d_conv,
                      padding=(d_conv - 1) // 2,
                  )

          # out proj =======================================
          self.out_proj = Linear(d_inner, d_model, bias=bias)
          self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

          if initialize in ["v0"]:
              self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
                  d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
              )
          elif initialize in ["v1"]:
              # simple init dt_projs, A_logs, Ds
              self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
              self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
              self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
              self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
          elif initialize in ["v2"]:
              # simple init dt_projs, A_logs, Ds
              self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
              self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
              self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((k_group, d_inner, dt_rank)))
              self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, d_inner)))

          if forward_type.startswith("xv2"):
              # xv2a produces d_inner-wide dts directly; no dt projection needed.
              del self.dt_projs_weight
              self.dt_projs_weight = None

      def forwardxv(self, x: torch.Tensor, **kwargs):
          """Run the "xv" SS2D forward pass.

          Args:
              x: input in the block's layout: ``(B, C, H, W)`` when
                  ``channel_first``, otherwise ``(B, H, W, C)``.
              **kwargs: ignored (``__initxv__`` binds ``mode`` here via ``partial``).

          Returns:
              ``dropout(out_proj(y))`` in the same layout as ``x``.
          """
          B, (H, W) = x.shape[0], (x.shape[2:4] if self.channel_first else x.shape[1:3])
          L = H * W
          force_fp32 = False
          delta_softplus = True
          out_norm = self.out_norm
          to_dtype = True
          channel_dim = 1 if self.channel_first else -1

          to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

          def selective_scan(u, delta, A, B, C, D, delta_bias, delta_softplus):
              return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex=True, backend=None)

          def conv_in_layout(conv, t):
              # nn.Conv2d expects (B, C, H, W). Channel-last tensors are permuted
              # around the conv and returned in their original layout.
              if self.channel_first:
                  return conv(t)
              return conv(t.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

          if self.iconv:
              conv_out = conv_in_layout(self.conv2d, x)
              # "_cpos": residual (positional) conv; otherwise conv + activation.
              x = (x + conv_out) if self.cpos else self.cact(conv_out)

          x = self.in_proj(x)

          if self.oconv2:
              x = conv_in_layout(self.conv2d, x)

          us, dts, Bs, Cs = x.split([self.d_inner, self.dts_dim, 4 * self.d_state, 4 * self.d_state], dim=channel_dim)

          _us = us
          us = cross_scan_fn(us.contiguous(), in_channel_first=self.channel_first, out_channel_first=True).view(B, -1, L)
          Bs = cross_scan_fn(Bs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
          Cs = cross_scan_fn(Cs.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=True).view(B, 4, -1, L)
          # Channels must sit on dim 1 for both F.conv1d and selective_scan.
          dts = cross_scan_fn(dts.contiguous(), in_channel_first=self.channel_first, out_channel_first=True, one_by_one=(self.dts_dim == 4 * self.dt_rank)).view(B, -1, L)
          if self.dts_dim in (self.dt_rank, 4 * self.dt_rank):
              # grouped 1x1 conv: project each direction's dt_rank slice to d_inner
              dts = F.conv1d(dts, self.dt_projs_weight.view(4 * self.d_inner, self.dt_rank, 1), None, groups=4)

          As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
          Ds = self.Ds.to(torch.float) # (K * c)
          delta_bias = self.dt_projs_bias.view(-1).to(torch.float) # (K * c)

          if force_fp32:
              us, dts, Bs, Cs = to_fp32(us, dts, Bs, Cs)

          ys: torch.Tensor = selective_scan(
              us, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
          ).view(B, 4, -1, H, W)
          # ys is (B, 4, C, H, W), i.e. channel-first; merge into the block's own layout.
          y: torch.Tensor = cross_merge_fn(ys.contiguous(), in_channel_first=True, out_channel_first=self.channel_first)
          y = y.view(B, -1, H, W) if self.channel_first else y.view(B, H, W, -1)
          y = out_norm(y)

          if getattr(self, "__DEBUG__", False):
              setattr(self, "__data__", dict(
                  A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
                  us=us, dts=dts, delta_bias=delta_bias,
                  ys=ys, y=y,
              ))

          y = (y.to(x.dtype) if to_dtype else y)

          y = self.out_act(y)

          if self.omul:
              y = y * _us

          if self.oconv:
              y = y + self.cact(conv_in_layout(self.oconv2d, _us))

          out = self.dropout(self.out_proj(y))
          return out
  # mamba2 support ================================
  class SS2Dm0:
      """Mixin implementing the Mamba2-style ("m0") SS2D variant.

      This variant is channel-last only: its projections are ``nn.Linear`` and
      its depthwise conv is wrapped in permutes to/from ``(B, C, H, W)``.
      """

      def __initm0__(
          self,
          # basic dims ===========
          d_model=96,
          d_state=16, # now with mamba2, dstate should be bigger...
          ssm_ratio=2.0,
          dt_rank="auto",
          act_layer=nn.GELU,
          # dwconv ===============
          d_conv=3, # < 2 means no conv
          conv_bias=True,
          # ======================
          dropout=0.0,
          bias=False,
          # dt init ==============
          dt_min=0.001,
          dt_max=0.1,
          dt_init="random",
          dt_scale=1.0,
          dt_init_floor=1e-4,
          initialize="v2",
          # ======================
          forward_type="m0",
          # ======================
          with_initial_state=False,
          # ======================
          channel_first=False,
          **kwargs,
      ):
          """Initialise an "m0" SS2D block.

          ``forward_type`` suffixes ``_no32``, ``_oact``, ``_noz``, ``_nozact``
          and the output-norm tags are stripped. The remainder must be "m0".

          Raises:
              ValueError: if ``channel_first`` is true, ``initialize`` is not
                  "v1"/"v2", or the stripped ``forward_type`` is not "m0".
              AssertionError: if ``d_inner`` is not divisible by ``dt_rank``.
          """
          factory_kwargs = {"device": None, "dtype": None}
          super().__init__()
          if channel_first:
              raise ValueError("SS2Dm0 ('m*' forward types) only supports channel_first=False")
          if initialize not in ("v1", "v2"):
              # Other values would leave Ds / A_logs / dt_projs_bias undefined.
              raise ValueError(f"SS2Dm0 supports initialize 'v1' or 'v2', got {initialize!r}")
          d_inner = int(ssm_ratio * d_model)
          dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
          assert d_inner % dt_rank == 0
          # forwardm0 reads this attribute; this variant is channel-last only.
          self.channel_first = False
          self.with_dconv = d_conv > 1
          Linear = nn.Linear
          self.forward = self.forwardm0

          # tags for forward_type ==============================
          checkpostfix = SS2Dv2.checkpostfix
          self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
          self.oact, forward_type = checkpostfix("_oact", forward_type)
          self.disable_z, forward_type = checkpostfix("_noz", forward_type)
          self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
          self.out_norm, forward_type = SS2Dv2.get_outnorm(forward_type, d_inner, False)

          # forward_type debug =======================================
          FORWARD_TYPES = dict(
              m0=partial(self.forward_corem0, force_fp32=False, dstate=d_state),
          )
          forward_core = FORWARD_TYPES.get(forward_type, None)
          if forward_core is None:
              raise ValueError(
                  f"unsupported forward_type for SS2Dm0: {forward_type!r} (after suffix stripping); "
                  f"expected one of {sorted(FORWARD_TYPES)}"
              )
          self.forward_core = forward_core
          k_group = 4

          # in proj =======================================
          d_proj = d_inner if self.disable_z else (d_inner * 2)
          self.in_proj = Linear(d_model, d_proj, bias=bias)
          self.act: nn.Module = act_layer()

          # conv =======================================
          if self.with_dconv:
              self.conv2d = nn.Sequential(
                  Permute(0, 3, 1, 2),
                  nn.Conv2d(
                      in_channels=d_inner,
                      out_channels=d_inner,
                      groups=d_inner,
                      bias=conv_bias,
                      kernel_size=d_conv,
                      padding=(d_conv - 1) // 2,
                      **factory_kwargs,
                  ),
                  Permute(0, 2, 3, 1),
              )

          # x proj ============================
          self.x_proj = [
              nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
              for _ in range(k_group)
          ]
          self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
          del self.x_proj

          # out proj =======================================
          self.out_act = nn.GELU() if self.oact else nn.Identity()
          self.out_proj = Linear(d_inner, d_model, bias=bias)
          self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

          if initialize in ["v1"]:
              # simple init dt_projs, A_logs, Ds
              self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
              self.A_logs = nn.Parameter(torch.randn((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
              self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((k_group, dt_rank))) # 0.1 is added in 0430
          elif initialize in ["v2"]:
              # simple init dt_projs, A_logs, Ds
              self.Ds = nn.Parameter(torch.ones((k_group, dt_rank, int(d_inner // dt_rank))))
              self.A_logs = nn.Parameter(torch.zeros((k_group, dt_rank))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
              self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((k_group, dt_rank)))

          # init state ============================
          self.initial_state = None
          if with_initial_state:
              self.initial_state = nn.Parameter(torch.zeros((1, k_group * dt_rank, int(d_inner // dt_rank), d_state)), requires_grad=False)

      def forward_corem0(
          self,
          x: torch.Tensor=None,
          # ==============================
          force_fp32=False, # True: input fp32
          chunk_size = 64,
          dstate = 64,
          # ==============================
          selective_scan_backend = None,
          scan_mode = "cross2d",
          scan_force_torch = False,
          # ==============================
          **kwargs,
      ):
          """Run the chunked selective scan over the cross-scanned directions.

          Args:
              x: channel-last input, unpacked as ``(B, H, W, RD)`` with
                  ``RD == R * D`` where ``(K, R, D) = self.Ds.shape``.
              force_fp32: cast the scan inputs to float32 before scanning.
              chunk_size: forwarded to ``selective_scan_chunk_fn``.
              dstate: state size N, used as the B/C split size of ``x_dbl``.
              selective_scan_backend: None, "triton" or "torch".
              scan_mode: "cross2d", "unidi" or "bidi".
              scan_force_torch: forwarded to the cross scan/merge helpers.

          Returns:
              ``out_norm`` of the merged output viewed as ``(B, H, W, -1)``,
              cast to ``x.dtype``.
          """
          assert scan_mode in ["unidi", "bidi", "cross2d"]
          assert selective_scan_backend in [None, "triton", "torch"]
          x_proj_bias = getattr(self, "x_proj_bias", None)
          to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

          N = dstate
          B, H, W, RD = x.shape
          K, R = self.A_logs.shape
          K, R, D = self.Ds.shape
          assert RD == R * D
          L = H * W
          KR = K * R
          _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=3)[scan_mode]

          initial_state = None
          if self.initial_state is not None:
              assert self.initial_state.shape[-1] == dstate
              initial_state = self.initial_state.detach().repeat(B, 1, 1, 1)

          # NOTE(review): the einsum below requires xs to be 4-D, i.e. (B, L, K, RD).
          xs = cross_scan_fn(x.view(B, H, W, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)
          x_dbl = torch.einsum("b l k d, k c d -> b l k c", xs, self.x_proj_weight)
          if x_proj_bias is not None:
              # x_dbl is "b l k c": broadcast a per-(direction, channel) bias
              # (assumed (K, C)-shaped) over batch and length.
              x_dbl = x_dbl + x_proj_bias.view(1, 1, K, -1)
          dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=3)
          xs = xs.contiguous().view(B, L, KR, D)
          dts = dts.contiguous().view(B, L, KR)
          Bs = Bs.contiguous().view(B, L, K, N)
          Cs = Cs.contiguous().view(B, L, K, N)
          if force_fp32:
              xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)

          As = -self.A_logs.to(torch.float).exp().view(KR)
          Ds = self.Ds.to(torch.float).view(KR, D)
          dt_bias = self.dt_projs_bias.view(KR)

          ys, final_state = selective_scan_chunk_fn(
              xs, dts, As, Bs, Cs, chunk_size=chunk_size, D=Ds, dt_bias=dt_bias,
              initial_states=initial_state, dt_softplus=True, return_final_states=True,
              backend=selective_scan_backend,
          )
          y: torch.Tensor = cross_merge_fn(ys.view(B, H, W, K, RD), in_channel_first=False, out_channel_first=False, scans=_scan_mode, force_torch=scan_force_torch)

          if getattr(self, "__DEBUG__", False):
              setattr(self, "__data__", dict(
                  A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=self.Ds,
                  us=xs, dts=dts, delta_bias=self.dt_projs_bias,
                  # "final_satte" kept for existing readers; "final_state" is the correct spelling.
                  initial_state=self.initial_state, final_satte=final_state, final_state=final_state,
                  ys=ys, y=y, H=H, W=W,
              ))
          if self.initial_state is not None:
              # Store the batch-summed final state as the next call's initial state.
              self.initial_state = nn.Parameter(final_state.detach().sum(0, keepdim=True), requires_grad=False)

          y = self.out_norm(y.view(B, H, W, -1))
          return y.to(x.dtype)

      def forwardm0(self, x: torch.Tensor, **kwargs):
          """Gated m0 forward pass on a channel-last ``x``.

          Mirrors the v2 gating: ``in_proj`` -> optional gate split ->
          optional depthwise conv -> ``act`` -> ``forward_core`` -> ``out_act``
          -> gate multiply -> ``out_proj`` -> ``dropout``.
          """
          x = self.in_proj(x)
          if not self.disable_z:
              x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
              if not self.disable_z_act:
                  z = self.act(z)
          if self.with_dconv:
              x = self.conv2d(x) # permutes internally; stays (b, h, w, d)
          x = self.act(x)
          y = self.forward_core(x)
          y = self.out_act(y)
          if not self.disable_z:
              y = y * z
          out = self.dropout(self.out_proj(y))
          return out
  class SS2D(nn.Module, SS2Dv0, SS2Dv2, SS2Dv3, SS2Dm0):
      """2D selective-scan block that dispatches to one of several implementations.

      The mixin initialiser is picked from ``forward_type``:

      * ``"v0"`` / ``"v0seq"`` -> ``__initv0__`` (``seq`` True for ``"v0seq"``);
      * prefix ``"xv"`` -> ``__initxv__``;
      * prefix ``"m"`` -> ``__initm0__``;
      * anything else -> ``__initv2__``.

      All named constructor arguments, plus any extra ``**kwargs``, are
      forwarded to the chosen initialiser as keywords.
      """

      def __init__(
          self,
          # basic dims ===========
          d_model=96,
          d_state=16,
          ssm_ratio=2.0,
          dt_rank="auto",
          act_layer=nn.SiLU,
          # dwconv ===============
          d_conv=3, # < 2 means no conv
          conv_bias=True,
          # ======================
          dropout=0.0,
          bias=False,
          # dt init ==============
          dt_min=0.001,
          dt_max=0.1,
          dt_init="random",
          dt_scale=1.0,
          dt_init_floor=1e-4,
          initialize="v0",
          # ======================
          forward_type="v2",
          channel_first=False,
          # ======================
          **kwargs,
      ):
          nn.Module.__init__(self)
          options = dict(kwargs)
          options.update(
              d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank,
              act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias,
              dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor,
              initialize=initialize, forward_type=forward_type, channel_first=channel_first,
          )
          if forward_type in ("v0", "v0seq"):
              self.__initv0__(seq=("seq" in forward_type), **options)
              return
          if forward_type.startswith("xv"):
              initializer = self.__initxv__
          elif forward_type.startswith("m"):
              initializer = self.__initm0__
          else:
              initializer = self.__initv2__
          initializer(**options)
  1011. # =====================================================
  class VSSBlock(nn.Module):
      """Residual VMamba block: an SS2D token mixer followed by an MLP.

      The SS2D branch exists when ``ssm_ratio > 0`` and the MLP branch when
      ``mlp_ratio > 0``. Each branch is applied as
      ``x + drop_path(branch(norm(x)))``, or ``x + drop_path(norm(branch(x)))``
      when ``post_norm`` is set. Extra ``**kwargs`` are accepted and ignored.
      """

      def __init__(
          self,
          hidden_dim: int = 0,
          drop_path: float = 0,
          norm_layer: nn.Module = nn.LayerNorm,
          channel_first=False,
          # =============================
          ssm_d_state: int = 16,
          ssm_ratio=2.0,
          ssm_dt_rank: Any = "auto",
          ssm_act_layer=nn.SiLU,
          ssm_conv: int = 3,
          ssm_conv_bias=True,
          ssm_drop_rate: float = 0,
          ssm_init="v0",
          forward_type="v2",
          # =============================
          mlp_ratio=4.0,
          mlp_act_layer=nn.GELU,
          mlp_drop_rate: float = 0.0,
          gmlp=False,
          # =============================
          use_checkpoint: bool = False,
          post_norm: bool = False,
          # =============================
          _SS2D: type = SS2D,
          **kwargs,
      ):
          super().__init__()
          self.ssm_branch = ssm_ratio > 0
          self.mlp_branch = mlp_ratio > 0
          self.use_checkpoint = use_checkpoint
          self.post_norm = post_norm

          if self.ssm_branch:
              self.norm = norm_layer(hidden_dim)
              # bias and dt_* options are left at the SS2D defaults.
              self.op = _SS2D(
                  d_model=hidden_dim,
                  d_state=ssm_d_state,
                  ssm_ratio=ssm_ratio,
                  dt_rank=ssm_dt_rank,
                  act_layer=ssm_act_layer,
                  # ==========================
                  d_conv=ssm_conv,
                  conv_bias=ssm_conv_bias,
                  # ==========================
                  dropout=ssm_drop_rate,
                  # ==========================
                  initialize=ssm_init,
                  # ==========================
                  forward_type=forward_type,
                  channel_first=channel_first,
              )

          # Used by both residual branches, so it must exist even when the
          # SS2D branch is disabled (ssm_ratio <= 0).
          self.drop_path = DropPath(drop_path)

          if self.mlp_branch:
              _MLP = Mlp if not gmlp else gMlp
              self.norm2 = norm_layer(hidden_dim)
              mlp_hidden_dim = int(hidden_dim * mlp_ratio)
              self.mlp = _MLP(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=channel_first)

      def _forward(self, input: torch.Tensor):
          """Apply the enabled residual branches in order (SS2D, then MLP)."""
          x = input
          if self.ssm_branch:
              if self.post_norm:
                  x = x + self.drop_path(self.norm(self.op(x)))
              else:
                  x = x + self.drop_path(self.op(self.norm(x)))
          if self.mlp_branch:
              if self.post_norm:
                  x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
              else:
                  x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
          return x

      def forward(self, input: torch.Tensor):
          """Run ``_forward``, through activation checkpointing when enabled."""
          if self.use_checkpoint:
              return checkpoint.checkpoint(self._forward, input)
          else:
              return self._forward(input)
  1095. class VSSM(nn.Module):
  def __init__(
      self,
      patch_size=4,
      in_chans=3,
      num_classes=1000,
      depths=[2, 2, 9, 2],
      dims=[96, 192, 384, 768],
      # =========================
      ssm_d_state=16,
      ssm_ratio=2.0,
      ssm_dt_rank="auto",
      ssm_act_layer="silu",
      ssm_conv=3,
      ssm_conv_bias=True,
      ssm_drop_rate=0.0,
      ssm_init="v0",
      forward_type="v2",
      # =========================
      mlp_ratio=4.0,
      mlp_act_layer="gelu",
      mlp_drop_rate=0.0,
      gmlp=False,
      # =========================
      drop_path_rate=0.1,
      patch_norm=True,
      norm_layer="LN", # "BN", "LN2D"
      downsample_version: str = "v2", # "v1", "v2", "v3", "none"
      patchembed_version: str = "v1", # "v1", "v2"
      use_checkpoint=False,
      # =========================
      posembed=False,
      imgsize=224,
      _SS2D=SS2D,
      # =========================
      **kwargs,
  ):
      """Build a hierarchical VMamba classifier.

      The network is a patch-embedding stem, then ``len(depths)`` stages of
      VSSBlocks (a downsample between stages), then a norm / global-average-pool
      / linear head. ``norm_layer`` "bn" or "ln2d" makes the whole network
      channel-first; "ln" keeps it channel-last. An int ``dims`` is expanded to
      ``dims * 2**i`` per stage. Drop-path rates increase linearly from 0 to
      ``drop_path_rate`` over all blocks.

      Raises:
          ValueError: for an unknown ``norm_layer`` or ``patchembed_version``,
              or an unknown ``downsample_version`` when more than one stage
              needs a downsample.
      """
      super().__init__()
      self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
      self.num_classes = num_classes
      self.num_layers = len(depths)
      if isinstance(dims, int):
          dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
      self.num_features = dims[-1]
      self.dims = dims
      dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule

      _NORMLAYERS = dict(
          ln=nn.LayerNorm,
          ln2d=LayerNorm2d,
          bn=nn.BatchNorm2d,
      )

      _ACTLAYERS = dict(
          silu=nn.SiLU,
          gelu=nn.GELU,
          relu=nn.ReLU,
          sigmoid=nn.Sigmoid,
      )

      norm_key = norm_layer.lower()
      norm_layer: nn.Module = _NORMLAYERS.get(norm_key, None)
      if norm_layer is None:
          # The norm is always used (downsample + head), so fail early and clearly.
          raise ValueError(f"unsupported norm_layer {norm_key!r}; expected one of {sorted(_NORMLAYERS)}")
      ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
      mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)

      self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None

      _make_patch_embed = dict(
          v1=self._make_patch_embed,
          v2=self._make_patch_embed_v2,
      ).get(patchembed_version, None)
      if _make_patch_embed is None:
          raise ValueError(f"unsupported patchembed_version {patchembed_version!r}; expected 'v1' or 'v2'")
      self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer, channel_first=self.channel_first)

      _make_downsample = dict(
          v1=PatchMerging2D,
          v2=self._make_downsample,
          v3=self._make_downsample_v3,
          none=(lambda *_, **_k: None),
      ).get(downsample_version, None)
      if _make_downsample is None and self.num_layers > 1:
          raise ValueError(f"unsupported downsample_version {downsample_version!r}; expected 'v1', 'v2', 'v3' or 'none'")

      self.layers = nn.ModuleList()
      for i_layer in range(self.num_layers):
          if i_layer < self.num_layers - 1:
              downsample = _make_downsample(
                  self.dims[i_layer],
                  self.dims[i_layer + 1],
                  norm_layer=norm_layer,
                  channel_first=self.channel_first,
              )
              if downsample is None:
                  # "none" yields None; nn.Sequential would try to call it in forward.
                  downsample = nn.Identity()
          else:
              downsample = nn.Identity()

          self.layers.append(self._make_layer(
              dim = self.dims[i_layer],
              drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
              use_checkpoint=use_checkpoint,
              norm_layer=norm_layer,
              downsample=downsample,
              channel_first=self.channel_first,
              # =================
              ssm_d_state=ssm_d_state,
              ssm_ratio=ssm_ratio,
              ssm_dt_rank=ssm_dt_rank,
              ssm_act_layer=ssm_act_layer,
              ssm_conv=ssm_conv,
              ssm_conv_bias=ssm_conv_bias,
              ssm_drop_rate=ssm_drop_rate,
              ssm_init=ssm_init,
              forward_type=forward_type,
              # =================
              mlp_ratio=mlp_ratio,
              mlp_act_layer=mlp_act_layer,
              mlp_drop_rate=mlp_drop_rate,
              gmlp=gmlp,
              # =================
              _SS2D=_SS2D,
          ))

      self.classifier = nn.Sequential(OrderedDict(
          norm=norm_layer(self.num_features), # B,H,W,C
          permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
          avgpool=nn.AdaptiveAvgPool2d(1),
          flatten=nn.Flatten(1),
          head=nn.Linear(self.num_features, num_classes),
      ))

      self.apply(self._init_weights)
  @staticmethod
  def _pos_embed(embed_dims, patch_size, img_size):
      """Create a learnable, channel-first positional embedding.

      Returns an ``nn.Parameter`` of shape
      ``(1, embed_dims, img_size // patch_size, img_size // patch_size)``,
      initialised with ``trunc_normal_`` at ``std=0.02``.
      """
      grid_side = img_size // patch_size
      embedding = nn.Parameter(torch.zeros(1, embed_dims, grid_side, grid_side))
      trunc_normal_(embedding, std=0.02)
      return embedding
  1214. def _init_weights(self, m: nn.Module):
  1215. if isinstance(m, nn.Linear):
  1216. trunc_normal_(m.weight, std=.02)
  1217. if isinstance(m, nn.Linear) and m.bias is not None:
  1218. nn.init.constant_(m.bias, 0)
  1219. elif isinstance(m, nn.LayerNorm):
  1220. nn.init.constant_(m.bias, 0)
  1221. nn.init.constant_(m.weight, 1.0)
  # Consulted when building the optimizer.
  @torch.jit.ignore
  def no_weight_decay(self):
      """Return the parameter names that must be excluded from weight decay."""
      excluded = {"pos_embed"}
      return excluded
  # Consulted when building the optimizer.
  @torch.jit.ignore
  def no_weight_decay_keywords(self):
      """Return name keywords excluded from weight decay (none; an empty dict)."""
      return dict()
  1230. @staticmethod
  1231. def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
  1232. # if channel first, then Norm and Output are both channel_first
  1233. return nn.Sequential(
  1234. nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
  1235. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1236. (norm_layer(embed_dim) if patch_norm else nn.Identity()),
  1237. )
  1238. @staticmethod
  1239. def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm, channel_first=False):
  1240. # if channel first, then Norm and Output are both channel_first
  1241. stride = patch_size // 2
  1242. kernel_size = stride + 1
  1243. padding = 1
  1244. return nn.Sequential(
  1245. nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
  1246. (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 2, 3, 1)),
  1247. (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
  1248. (nn.Identity() if (channel_first or (not patch_norm)) else Permute(0, 3, 1, 2)),
  1249. nn.GELU(),
  1250. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
  1251. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1252. (norm_layer(embed_dim) if patch_norm else nn.Identity()),
  1253. )
  1254. @staticmethod
  1255. def _make_downsample(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
  1256. # if channel first, then Norm and Output are both channel_first
  1257. return nn.Sequential(
  1258. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1259. nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
  1260. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1261. norm_layer(out_dim),
  1262. )
  1263. @staticmethod
  1264. def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm, channel_first=False):
  1265. # if channel first, then Norm and Output are both channel_first
  1266. return nn.Sequential(
  1267. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1268. nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
  1269. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1270. norm_layer(out_dim),
  1271. )
  @staticmethod
  def _make_layer(
      dim=96,
      drop_path=[0.1, 0.1],
      use_checkpoint=False,
      norm_layer=nn.LayerNorm,
      downsample=nn.Identity(),
      channel_first=False,
      # ===========================
      ssm_d_state=16,
      ssm_ratio=2.0,
      ssm_dt_rank="auto",
      ssm_act_layer=nn.SiLU,
      ssm_conv=3,
      ssm_conv_bias=True,
      ssm_drop_rate=0.0,
      ssm_init="v0",
      forward_type="v2",
      # ===========================
      mlp_ratio=4.0,
      mlp_act_layer=nn.GELU,
      mlp_drop_rate=0.0,
      gmlp=False,
      # ===========================
      _SS2D=SS2D,
      **kwargs,
  ):
      """Build one stage: ``len(drop_path)`` VSSBlocks followed by ``downsample``.

      ``drop_path`` supplies one stochastic-depth rate per block. Every other
      option is passed unchanged to each ``VSSBlock``; extra ``**kwargs`` are
      ignored. Returns ``nn.Sequential`` with children ``blocks`` and
      ``downsample``.
      """
      stage_blocks = [
          VSSBlock(
              hidden_dim=dim,
              drop_path=rate,
              norm_layer=norm_layer,
              channel_first=channel_first,
              ssm_d_state=ssm_d_state,
              ssm_ratio=ssm_ratio,
              ssm_dt_rank=ssm_dt_rank,
              ssm_act_layer=ssm_act_layer,
              ssm_conv=ssm_conv,
              ssm_conv_bias=ssm_conv_bias,
              ssm_drop_rate=ssm_drop_rate,
              ssm_init=ssm_init,
              forward_type=forward_type,
              mlp_ratio=mlp_ratio,
              mlp_act_layer=mlp_act_layer,
              mlp_drop_rate=mlp_drop_rate,
              gmlp=gmlp,
              use_checkpoint=use_checkpoint,
              _SS2D=_SS2D,
          )
          for rate in drop_path
      ]
      return nn.Sequential(OrderedDict(
          blocks=nn.Sequential(*stage_blocks),
          downsample=downsample,
      ))
  def forward(self, x: torch.Tensor):
      """Run the stem, optional positional embedding, every stage, then the head."""
      x = self.patch_embed(x)
      pos = self.pos_embed
      if pos is not None:
          # pos_embed is stored channel-first; move channels last to match x when needed.
          x = x + (pos if self.channel_first else pos.permute(0, 2, 3, 1))
      for stage in self.layers:
          x = stage(x)
      return self.classifier(x)
  1337. def flops(self, shape=(3, 224, 224), verbose=True):
  1338. # shape = self.__input_shape__[1:]
  1339. supported_ops={
  1340. "aten::silu": None, # as relu is in _IGNORED_OPS
  1341. "aten::neg": None, # as relu is in _IGNORED_OPS
  1342. "aten::exp": None, # as relu is in _IGNORED_OPS
  1343. "aten::flip": None, # as permute is in _IGNORED_OPS
  1344. # "prim::PythonOp.CrossScan": None,
  1345. # "prim::PythonOp.CrossMerge": None,
  1346. "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose),
  1347. }
  1348. model = copy.deepcopy(self)
  1349. model.cuda().eval()
  1350. input = torch.randn((1, *shape), device=next(model.parameters()).device)
  1351. params = parameter_count(model)[""]
  1352. Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
  1353. del model, input
  1354. return sum(Gflops.values()) * 1e9
  1355. return f"params {params} GFLOPs {sum(Gflops.values())}"
  1356. # renames keys so that checkpoints saved by the previous training code still load
  1357. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  1358. def check_name(src, state_dict: dict = state_dict, strict=False):
  1359. if strict:
  1360. if prefix + src in list(state_dict.keys()):
  1361. return True
  1362. else:
  1363. key = prefix + src
  1364. for k in list(state_dict.keys()):
  1365. if k.startswith(key):
  1366. return True
  1367. return False
  1368. def change_name(src, dst, state_dict: dict = state_dict, strict=False):
  1369. if strict:
  1370. if prefix + src in list(state_dict.keys()):
  1371. state_dict[prefix + dst] = state_dict[prefix + src]
  1372. state_dict.pop(prefix + src)
  1373. else:
  1374. key = prefix + src
  1375. for k in list(state_dict.keys()):
  1376. if k.startswith(key):
  1377. new_k = prefix + dst + k[len(key):]
  1378. state_dict[new_k] = state_dict[k]
  1379. state_dict.pop(k)
  1380. if check_name("pos_embed", strict=True):
  1381. srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"]
  1382. state_dict[prefix + "pos_embed"] = F.interpolate(srcEmb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic").to(srcEmb.device)
  1383. change_name("patch_embed.proj", "patch_embed.0")
  1384. change_name("patch_embed.norm", "patch_embed.2")
  1385. for i in range(100):
  1386. for j in range(100):
  1387. change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm")
  1388. change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op")
  1389. change_name("norm", "classifier.norm")
  1390. change_name("head", "classifier.head")
  1391. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  1392. # compatible with openmmlab
  1393. class Backbone_VSSM(VSSM):
  1394. def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
  1395. kwargs.update(norm_layer=norm_layer)
  1396. super().__init__(**kwargs)
  1397. self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
  1398. _NORMLAYERS = dict(
  1399. ln=nn.LayerNorm,
  1400. ln2d=LayerNorm2d,
  1401. bn=nn.BatchNorm2d,
  1402. )
  1403. norm_layer: nn.Module = _NORMLAYERS.get(norm_layer.lower(), None)
  1404. self.out_indices = out_indices
  1405. for i in out_indices:
  1406. layer = norm_layer(self.dims[i])
  1407. layer_name = f'outnorm{i}'
  1408. self.add_module(layer_name, layer)
  1409. del self.classifier
  1410. self.load_pretrained(pretrained)
  1411. def load_pretrained(self, ckpt=None, key="model"):
  1412. if ckpt is None:
  1413. return
  1414. try:
  1415. _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu"))
  1416. print(f"Successfully load ckpt {ckpt}")
  1417. incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
  1418. print(incompatibleKeys)
  1419. except Exception as e:
  1420. print(f"Failed loading checkpoint form {ckpt}: {e}")
  1421. def forward(self, x):
  1422. def layer_forward(l, x):
  1423. x = l.blocks(x)
  1424. y = l.downsample(x)
  1425. return x, y
  1426. x = self.patch_embed(x)
  1427. outs = []
  1428. for i, layer in enumerate(self.layers):
  1429. o, x = layer_forward(layer, x) # (B, H, W, C)
  1430. if i in self.out_indices:
  1431. norm_layer = getattr(self, f'outnorm{i}')
  1432. out = norm_layer(o)
  1433. if not self.channel_first:
  1434. out = out.permute(0, 3, 1, 2)
  1435. outs.append(out.contiguous())
  1436. if len(self.out_indices) == 0:
  1437. return x
  1438. return outs
  1439. # =====================================================
  1440. def vanilla_vmamba_tiny():
  1441. return VSSM(
  1442. depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2,
  1443. patch_size=4, in_chans=3, num_classes=1000,
  1444. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1445. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1446. ssm_init="v0", forward_type="v0",
  1447. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1448. patch_norm=True, norm_layer="ln",
  1449. downsample_version="v1", patchembed_version="v1",
  1450. use_checkpoint=False, posembed=False, imgsize=224,
  1451. )
  1452. def vanilla_vmamba_small():
  1453. return VSSM(
  1454. depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3,
  1455. patch_size=4, in_chans=3, num_classes=1000,
  1456. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1457. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1458. ssm_init="v0", forward_type="v0",
  1459. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1460. patch_norm=True, norm_layer="ln",
  1461. downsample_version="v1", patchembed_version="v1",
  1462. use_checkpoint=False, posembed=False, imgsize=224,
  1463. )
  1464. def vanilla_vmamba_base():
  1465. return VSSM(
  1466. depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6,
  1467. patch_size=4, in_chans=3, num_classes=1000,
  1468. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1469. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1470. ssm_init="v0", forward_type="v0",
  1471. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1472. patch_norm=True, norm_layer="ln",
  1473. downsample_version="v1", patchembed_version="v1",
  1474. use_checkpoint=False, posembed=False, imgsize=224,
  1475. )
  1476. # =====================================================
  1477. def vmamba_tiny_s2l5(channel_first=True):
  1478. return VSSM(
  1479. depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2,
  1480. patch_size=4, in_chans=3, num_classes=1000,
  1481. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1482. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1483. ssm_init="v0", forward_type="v05_noz",
  1484. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1485. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1486. downsample_version="v3", patchembed_version="v2",
  1487. use_checkpoint=False, posembed=False, imgsize=224,
  1488. )
  1489. def vmamba_small_s2l15(channel_first=True):
  1490. return VSSM(
  1491. depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3,
  1492. patch_size=4, in_chans=3, num_classes=1000,
  1493. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1494. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1495. ssm_init="v0", forward_type="v05_noz",
  1496. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1497. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1498. downsample_version="v3", patchembed_version="v2",
  1499. use_checkpoint=False, posembed=False, imgsize=224,
  1500. )
  1501. def vmamba_base_s2l15(channel_first=True):
  1502. return VSSM(
  1503. depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6,
  1504. patch_size=4, in_chans=3, num_classes=1000,
  1505. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1506. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1507. ssm_init="v0", forward_type="v05_noz",
  1508. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1509. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1510. downsample_version="v3", patchembed_version="v2",
  1511. use_checkpoint=False, posembed=False, imgsize=224,
  1512. )
  1513. # =====================================================
  1514. def vmamba_tiny_s1l8(channel_first=True):
  1515. return VSSM(
  1516. depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2,
  1517. patch_size=4, in_chans=3, num_classes=1000,
  1518. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1519. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1520. ssm_init="v0", forward_type="v05_noz",
  1521. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1522. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1523. downsample_version="v3", patchembed_version="v2",
  1524. use_checkpoint=False, posembed=False, imgsize=224,
  1525. )
  1526. def vmamba_small_s1l20(channel_first=True):
  1527. return VSSM(
  1528. depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3,
  1529. patch_size=4, in_chans=3, num_classes=1000,
  1530. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1531. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1532. ssm_init="v0", forward_type="v05_noz",
  1533. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1534. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1535. downsample_version="v3", patchembed_version="v2",
  1536. use_checkpoint=False, posembed=False, imgsize=224,
  1537. )
  1538. def vmamba_base_s1l20(channel_first=True):
  1539. return VSSM(
  1540. depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5,
  1541. patch_size=4, in_chans=3, num_classes=1000,
  1542. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1543. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1544. ssm_init="v0", forward_type="v05_noz",
  1545. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1546. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1547. downsample_version="v3", patchembed_version="v2",
  1548. use_checkpoint=False, posembed=False, imgsize=224,
  1549. )
  1550. # mamba2 support =====================================================
  1551. # FLOP counting does not work for the mamba2 models yet!
  1552. def vmamba_tiny_m2():
  1553. return VSSM(
  1554. depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2,
  1555. patch_size=4, in_chans=3, num_classes=1000,
  1556. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1557. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1558. ssm_init="v2", forward_type="m0_noz",
  1559. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1560. patch_norm=True, norm_layer="ln",
  1561. downsample_version="v3", patchembed_version="v2",
  1562. use_checkpoint=False, posembed=False, imgsize=224,
  1563. )
  1564. def vmamba_small_m2():
  1565. return VSSM(
  1566. depths=[2, 2, 12, 2], dims=96, drop_path_rate=0.3,
  1567. patch_size=4, in_chans=3, num_classes=1000,
  1568. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1569. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1570. ssm_init="v2", forward_type="m0_noz",
  1571. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1572. patch_norm=True, norm_layer="ln",
  1573. downsample_version="v3", patchembed_version="v2",
  1574. use_checkpoint=False, posembed=False, imgsize=224,
  1575. )
  1576. def vmamba_base_m2():
  1577. return VSSM(
  1578. depths=[2, 2, 12, 2], dims=128, drop_path_rate=0.3,
  1579. patch_size=4, in_chans=3, num_classes=1000,
  1580. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1581. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1582. ssm_init="v2", forward_type="m0_noz",
  1583. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1584. patch_norm=True, norm_layer="ln",
  1585. downsample_version="v3", patchembed_version="v2",
  1586. use_checkpoint=False, posembed=False, imgsize=224,
  1587. )
  1588. if __name__ == "__main__":
  1589. model_ref = vmamba_tiny_s1l8()
  1590. model = VSSM(
  1591. depths=[2, 2, 4, 2], dims=96, drop_path_rate=0.2,
  1592. patch_size=4, in_chans=3, num_classes=1000,
  1593. ssm_d_state=64, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="gelu",
  1594. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1595. ssm_init="v2", forward_type="m0_noz",
  1596. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1597. patch_norm=True, norm_layer="ln",
  1598. downsample_version="v3", patchembed_version="v2",
  1599. use_checkpoint=False, posembed=False, imgsize=224,
  1600. )
  1601. print(parameter_count(model)[""])
  1602. print(model.flops()) # wrong
  1603. model.cuda().train()
  1604. model_ref.cuda().train()
  1605. def bench(model):
  1606. import time
  1607. inp = torch.randn((128, 3, 224, 224)).cuda()
  1608. for _ in range(30):
  1609. model(inp)
  1610. torch.cuda.synchronize()
  1611. tim = time.time()
  1612. for _ in range(30):
  1613. model(inp)
  1614. torch.cuda.synchronize()
  1615. tim1 = time.time() - tim
  1616. for _ in range(30):
  1617. model(inp).sum().backward()
  1618. torch.cuda.synchronize()
  1619. tim = time.time()
  1620. for _ in range(30):
  1621. model(inp).sum().backward()
  1622. torch.cuda.synchronize()
  1623. tim2 = time.time() - tim
  1624. return tim1 / 30, tim2 / 30
  1625. print(bench(model_ref))
  1626. print(bench(model))
  1627. breakpoint()