vmamba.py 101 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483
  1. ##########################################################
  2. # simplified version
  3. # just one file and include everything
  4. # written by MzeroMiko
  5. ##########################################################
  6. ##########################################################
  7. # usage:
  8. # conda create -n vmamba python=3.10
  9. # pip install torch==2.2 torchvision torchaudio triton pytest chardet yacs termcolor fvcore seaborn packaging ninja einops numpy==1.24.4 timm==0.4.12
  10. # pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.2cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
  11. # python vmamba.py
  12. ##########################################################
  13. ##########################################################
  14. # csm_triton.py
  15. ##########################################################
  16. import torch
  17. import warnings
  18. WITH_TRITON = True
  19. # WITH_TRITON = False
  20. try:
  21. import triton
  22. import triton.language as tl
  23. except:
  24. WITH_TRITON = False
  25. warnings.warn("Triton not installed, fall back to pytorch implements.")
  26. # to make sure cached_property can be loaded for triton
  27. if WITH_TRITON:
  28. try:
  29. from functools import cached_property
  30. except:
  31. warnings.warn("if you are using py37, add this line to functools.py: "
  32. "cached_property = lambda func: property(lru_cache()(func))")
  33. # torch implementation ========================================
  34. def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  35. if in_channel_first:
  36. B, C, H, W = x.shape
  37. if scans == 0:
  38. y = x.new_empty((B, 4, C, H * W))
  39. y[:, 0, :, :] = x.flatten(2, 3)
  40. y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
  41. y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
  42. elif scans == 1:
  43. y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
  44. elif scans == 2:
  45. y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
  46. y = torch.cat([y, y.flip(dims=[-1])], dim=1)
  47. elif scans == 3:
  48. y = x.new_empty((B, 4, C, H * W))
  49. y[:, 0, :, :] = x.flatten(2, 3)
  50. y[:, 1, :, :] = torch.rot90(x, 1, dims=(2, 3)).flatten(2, 3)
  51. y[:, 2, :, :] = torch.rot90(x, 2, dims=(2, 3)).flatten(2, 3)
  52. y[:, 3, :, :] = torch.rot90(x, 3, dims=(2, 3)).flatten(2, 3)
  53. else:
  54. B, H, W, C = x.shape
  55. if scans == 0:
  56. y = x.new_empty((B, H * W, 4, C))
  57. y[:, :, 0, :] = x.flatten(1, 2)
  58. y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
  59. y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
  60. elif scans == 1:
  61. y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
  62. elif scans == 2:
  63. y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
  64. y = torch.cat([y, y.flip(dims=[1])], dim=2)
  65. elif scans == 3:
  66. y = x.new_empty((B, H * W, 4, C))
  67. y[:, :, 0, :] = x.flatten(1, 2)
  68. y[:, :, 1, :] = torch.rot90(x, 1, dims=(1, 2)).flatten(1, 2)
  69. y[:, :, 2, :] = torch.rot90(x, 2, dims=(1, 2)).flatten(1, 2)
  70. y[:, :, 3, :] = torch.rot90(x, 3, dims=(1, 2)).flatten(1, 2)
  71. if in_channel_first and (not out_channel_first):
  72. y = y.permute(0, 3, 1, 2).contiguous()
  73. elif (not in_channel_first) and out_channel_first:
  74. y = y.permute(0, 2, 3, 1).contiguous()
  75. return y
  76. def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  77. if out_channel_first:
  78. B, K, D, H, W = y.shape
  79. y = y.view(B, K, D, -1)
  80. if scans == 0:
  81. y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  82. y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
  83. elif scans == 1:
  84. y = y.sum(1)
  85. elif scans == 2:
  86. y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  87. y = y.sum(1)
  88. elif scans == 3:
  89. oy = y[:, 0, :, :].contiguous().view(B, D, -1)
  90. oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3)
  91. oy = oy + torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3)
  92. oy = oy + torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3)
  93. y = oy
  94. else:
  95. B, H, W, K, D = y.shape
  96. y = y.view(B, -1, K, D)
  97. if scans == 0:
  98. y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
  99. y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
  100. elif scans == 1:
  101. y = y.sum(2)
  102. elif scans == 2:
  103. y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
  104. y = y.sum(2)
  105. elif scans == 3:
  106. oy = y[:, :, 0, :].contiguous().view(B, -1, D)
  107. oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2)
  108. oy = oy + torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2)
  109. oy = oy + torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2)
  110. y = oy
  111. if in_channel_first and (not out_channel_first):
  112. y = y.permute(0, 2, 1).contiguous()
  113. elif (not in_channel_first) and out_channel_first:
  114. y = y.permute(0, 2, 1).contiguous()
  115. return y
  116. def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  117. if in_channel_first:
  118. B, _, C, H, W = x.shape
  119. if scans == 0:
  120. y = torch.stack([
  121. x[:, 0].flatten(2, 3),
  122. x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
  123. torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
  124. torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
  125. ], dim=1)
  126. elif scans == 1:
  127. y = x.flatten(2, 3)
  128. elif scans == 2:
  129. y = torch.stack([
  130. x[:, 0].flatten(2, 3),
  131. x[:, 1].flatten(2, 3),
  132. torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
  133. torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
  134. ], dim=1)
  135. elif scans == 3:
  136. y = torch.stack([
  137. x[:, 0, :, :, :].flatten(2, 3),
  138. torch.rot90(x[:, 1, :, :, :], 1, dims=(2, 3)).flatten(2, 3),
  139. torch.rot90(x[:, 2, :, :, :], 2, dims=(2, 3)).flatten(2, 3),
  140. torch.rot90(x[:, 3, :, :, :], 3, dims=(2, 3)).flatten(2, 3),
  141. ], dim=1)
  142. else:
  143. B, H, W, _, C = x.shape
  144. if scans == 0:
  145. y = torch.stack([
  146. x[:, :, :, 0].flatten(1, 2),
  147. x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
  148. torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
  149. torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
  150. ], dim=2)
  151. elif scans == 1:
  152. y = x.flatten(1, 2)
  153. elif scans == 2:
  154. y = torch.stack([
  155. x[:, 0].flatten(1, 2),
  156. x[:, 1].flatten(1, 2),
  157. torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
  158. torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
  159. ], dim=2)
  160. elif scans == 3:
  161. y = torch.stack([
  162. x[:, :, :, 0, :].flatten(1, 2),
  163. torch.rot90(x[:, :, :, 1, :], 1, dims=(1, 2)).flatten(1, 2),
  164. torch.rot90(x[:, :, :, 2, :], 2, dims=(1, 2)).flatten(1, 2),
  165. torch.rot90(x[:, :, :, 3, :], 3, dims=(1, 2)).flatten(1, 2),
  166. ], dim=1)
  167. if in_channel_first and (not out_channel_first):
  168. y = y.permute(0, 3, 1, 2).contiguous()
  169. elif (not in_channel_first) and out_channel_first:
  170. y = y.permute(0, 2, 3, 1).contiguous()
  171. return y
  172. def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  173. if out_channel_first:
  174. B, K, D, H, W = y.shape
  175. y = y.view(B, K, D, -1)
  176. if scans == 0:
  177. y = torch.stack([
  178. y[:, 0],
  179. y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
  180. torch.flip(y[:, 2], dims=[-1]),
  181. torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
  182. ], dim=1)
  183. elif scans == 1:
  184. y = y
  185. elif scans == 2:
  186. y = torch.stack([
  187. y[:, 0],
  188. y[:, 1],
  189. torch.flip(y[:, 2], dims=[-1]),
  190. torch.flip(y[:, 3], dims=[-1]),
  191. ], dim=1)
  192. elif scans == 3:
  193. y = torch.stack([
  194. y[:, 0, :, :].contiguous().view(B, D, -1),
  195. torch.rot90(y.view(B, K, D, W, H)[:, 1, :, :, :], -1, dims=(2, 3)).flatten(2, 3),
  196. torch.rot90(y.view(B, K, D, H, W)[:, 2, :, :, :], -2, dims=(2, 3)).flatten(2, 3),
  197. torch.rot90(y.view(B, K, D, W, H)[:, 3, :, :, :], -3, dims=(2, 3)).flatten(2, 3),
  198. ], dim=1)
  199. else:
  200. B, H, W, K, D = y.shape
  201. y = y.view(B, -1, K, D)
  202. if scans == 0:
  203. y = torch.stack([
  204. y[:, :, 0],
  205. y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
  206. torch.flip(y[:, :, 2], dims=[1]),
  207. torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
  208. ], dim=2)
  209. elif scans == 1:
  210. y = y
  211. elif scans == 2:
  212. y = torch.stack([
  213. y[:, :, 0],
  214. y[:, :, 1],
  215. torch.flip(y[:, :, 2], dims=[1]),
  216. torch.flip(y[:, :, 3], dims=[1]),
  217. ], dim=2)
  218. elif scans == 3:
  219. y = torch.stack([
  220. y[:, :, 0, :].contiguous().view(B, -1, D),
  221. torch.rot90(y.view(B, W, H, K, D)[:, :, :, 1, :], -1, dims=(1, 2)).flatten(1, 2),
  222. torch.rot90(y.view(B, H, W, K, D)[:, :, :, 2, :], -2, dims=(1, 2)).flatten(1, 2),
  223. torch.rot90(y.view(B, W, H, K, D)[:, :, :, 3, :], -3, dims=(1, 2)).flatten(1, 2),
  224. ], dim=2)
  225. if out_channel_first and (not in_channel_first):
  226. y = y.permute(0, 3, 1, 2).contiguous()
  227. elif (not out_channel_first) and in_channel_first:
  228. y = y.permute(0, 2, 3, 1).contiguous()
  229. return y
  230. class CrossScanF(torch.autograd.Function):
  231. @staticmethod
  232. def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  233. # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  234. # y: (B, 4, C, H * W) | (B, H * W, 4, C)
  235. ctx.in_channel_first = in_channel_first
  236. ctx.out_channel_first = out_channel_first
  237. ctx.one_by_one = one_by_one
  238. ctx.scans = scans
  239. if one_by_one:
  240. B, K, C, H, W = x.shape
  241. if not in_channel_first:
  242. B, H, W, K, C = x.shape
  243. else:
  244. B, C, H, W = x.shape
  245. if not in_channel_first:
  246. B, H, W, C = x.shape
  247. ctx.shape = (B, C, H, W)
  248. _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
  249. y = _fn(x, in_channel_first, out_channel_first, scans)
  250. return y
  251. @staticmethod
  252. def backward(ctx, ys: torch.Tensor):
  253. # out: (b, k, d, l)
  254. in_channel_first = ctx.in_channel_first
  255. out_channel_first = ctx.out_channel_first
  256. one_by_one = ctx.one_by_one
  257. scans = ctx.scans
  258. B, C, H, W = ctx.shape
  259. ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
  260. _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
  261. y = _fn(ys, in_channel_first, out_channel_first, scans)
  262. if one_by_one:
  263. y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
  264. else:
  265. y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
  266. return y, None, None, None, None
  267. class CrossMergeF(torch.autograd.Function):
  268. @staticmethod
  269. def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  270. # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  271. # y: (B, 4, C, H * W) | (B, H * W, 4, C)
  272. ctx.in_channel_first = in_channel_first
  273. ctx.out_channel_first = out_channel_first
  274. ctx.one_by_one = one_by_one
  275. ctx.scans = scans
  276. B, K, C, H, W = ys.shape
  277. if not out_channel_first:
  278. B, H, W, K, C = ys.shape
  279. ctx.shape = (B, C, H, W)
  280. _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
  281. y = _fn(ys, in_channel_first, out_channel_first, scans)
  282. return y
  283. @staticmethod
  284. def backward(ctx, x: torch.Tensor):
  285. # B, D, L = x.shape
  286. # out: (b, k, d, h, w)
  287. in_channel_first = ctx.in_channel_first
  288. out_channel_first = ctx.out_channel_first
  289. one_by_one = ctx.one_by_one
  290. scans = ctx.scans
  291. B, C, H, W = ctx.shape
  292. if not one_by_one:
  293. if in_channel_first:
  294. x = x.view(B, C, H, W)
  295. else:
  296. x = x.view(B, H, W, C)
  297. else:
  298. if in_channel_first:
  299. x = x.view(B, 4, C, H, W)
  300. else:
  301. x = x.view(B, H, W, 4, C)
  302. _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
  303. x = _fn(x, in_channel_first, out_channel_first, scans)
  304. x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
  305. return x, None, None, None, None
  306. # triton implements ========================================
  307. @triton.jit
  308. def triton_cross_scan_flex(
  309. x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  310. y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
  311. x_layout: tl.constexpr,
  312. y_layout: tl.constexpr,
  313. operation: tl.constexpr,
  314. onebyone: tl.constexpr,
  315. scans: tl.constexpr,
  316. BC: tl.constexpr,
  317. BH: tl.constexpr,
  318. BW: tl.constexpr,
  319. DC: tl.constexpr,
  320. DH: tl.constexpr,
  321. DW: tl.constexpr,
  322. NH: tl.constexpr,
  323. NW: tl.constexpr,
  324. ):
  325. # x_layout = 0
  326. # y_layout = 1 # 0 BCHW, 1 BHWC
  327. # operation = 0 # 0 scan, 1 merge
  328. # onebyone = 0 # 0 false, 1 true
  329. # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
  330. i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
  331. i_h, i_w = (i_hw // NW), (i_hw % NW)
  332. _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
  333. _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
  334. _mask_hw = _mask_h[:, None] & _mask_w[None, :]
  335. _for_C = min(DC - i_c * BC, BC)
  336. pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
  337. pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
  338. neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
  339. neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
  340. if scans == 0:
  341. # none; trans; flip; trans + flip;
  342. HWRoute0 = pos_h * DW + pos_w
  343. HWRoute1 = pos_w * DH + pos_h # trans
  344. HWRoute2 = neg_h * DW + neg_w # flip
  345. HWRoute3 = neg_w * DH + neg_h # trans + flip
  346. elif scans == 1:
  347. # none; none; none; none;
  348. HWRoute0 = pos_h * DW + pos_w
  349. HWRoute1 = HWRoute0
  350. HWRoute2 = HWRoute0
  351. HWRoute3 = HWRoute0
  352. elif scans == 2:
  353. # none; none; flip; flip;
  354. HWRoute0 = pos_h * DW + pos_w
  355. HWRoute1 = HWRoute0
  356. HWRoute2 = neg_h * DW + neg_w # flip
  357. HWRoute3 = HWRoute2
  358. elif scans == 3:
  359. # none; rot90; rot180==flip; rot270;
  360. HWRoute0 = pos_h * DW + pos_w
  361. HWRoute1 = neg_w * DH + pos_h
  362. HWRoute2 = neg_h * DW + neg_w
  363. HWRoute3 = pos_w * DH + neg_h
  364. _tmp1 = DC * DH * DW
  365. y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
  366. if y_layout == 0:
  367. p_y1 = y_ptr_base + HWRoute0
  368. p_y2 = y_ptr_base + _tmp1 + HWRoute1
  369. p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
  370. p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
  371. else:
  372. p_y1 = y_ptr_base + HWRoute0 * 4 * DC
  373. p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
  374. p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
  375. p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
  376. if onebyone == 0:
  377. x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
  378. if x_layout == 0:
  379. p_x = x_ptr_base + HWRoute0
  380. else:
  381. p_x = x_ptr_base + HWRoute0 * DC
  382. if operation == 0:
  383. for idxc in range(_for_C):
  384. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  385. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  386. _x = tl.load(p_x + _idx_x, mask=_mask_hw)
  387. tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
  388. tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
  389. tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
  390. tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
  391. elif operation == 1:
  392. for idxc in range(_for_C):
  393. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  394. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  395. _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
  396. _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
  397. _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
  398. _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
  399. tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
  400. else:
  401. x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
  402. if x_layout == 0:
  403. p_x1 = x_ptr_base + HWRoute0
  404. p_x2 = p_x1 + _tmp1
  405. p_x3 = p_x2 + _tmp1
  406. p_x4 = p_x3 + _tmp1
  407. else:
  408. p_x1 = x_ptr_base + HWRoute0 * 4 * DC
  409. p_x2 = p_x1 + DC
  410. p_x3 = p_x2 + DC
  411. p_x4 = p_x3 + DC
  412. if operation == 0:
  413. for idxc in range(_for_C):
  414. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  415. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  416. tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  417. tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  418. tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  419. tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  420. else:
  421. for idxc in range(_for_C):
  422. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  423. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  424. tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
  425. tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
  426. tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
  427. tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
  428. class CrossScanTritonF(torch.autograd.Function):
  429. @staticmethod
  430. def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  431. if one_by_one:
  432. if in_channel_first:
  433. B, _, C, H, W = x.shape
  434. else:
  435. B, H, W, _, C = x.shape
  436. else:
  437. if in_channel_first:
  438. B, C, H, W = x.shape
  439. else:
  440. B, H, W, C = x.shape
  441. B, C, H, W = int(B), int(C), int(H), int(W)
  442. BC, BH, BW = 1, 32, 32
  443. NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
  444. ctx.in_channel_first = in_channel_first
  445. ctx.out_channel_first = out_channel_first
  446. ctx.one_by_one = one_by_one
  447. ctx.scans = scans
  448. ctx.shape = (B, C, H, W)
  449. ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
  450. y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
  451. triton_cross_scan_flex[(NH * NW, NC, B)](
  452. x.contiguous(), y,
  453. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
  454. BC, BH, BW, C, H, W, NH, NW
  455. )
  456. return y
  457. @staticmethod
  458. def backward(ctx, y: torch.Tensor):
  459. in_channel_first = ctx.in_channel_first
  460. out_channel_first = ctx.out_channel_first
  461. one_by_one = ctx.one_by_one
  462. scans = ctx.scans
  463. B, C, H, W = ctx.shape
  464. BC, BH, BW, NC, NH, NW = ctx.triton_shape
  465. if one_by_one:
  466. x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
  467. else:
  468. x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
  469. triton_cross_scan_flex[(NH * NW, NC, B)](
  470. x, y.contiguous(),
  471. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
  472. BC, BH, BW, C, H, W, NH, NW
  473. )
  474. return x, None, None, None, None
  475. class CrossMergeTritonF(torch.autograd.Function):
  476. @staticmethod
  477. def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  478. if out_channel_first:
  479. B, _, C, H, W = y.shape
  480. else:
  481. B, H, W, _, C = y.shape
  482. B, C, H, W = int(B), int(C), int(H), int(W)
  483. BC, BH, BW = 1, 32, 32
  484. NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
  485. ctx.in_channel_first = in_channel_first
  486. ctx.out_channel_first = out_channel_first
  487. ctx.one_by_one = one_by_one
  488. ctx.scans = scans
  489. ctx.shape = (B, C, H, W)
  490. ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
  491. if one_by_one:
  492. x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
  493. else:
  494. x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
  495. triton_cross_scan_flex[(NH * NW, NC, B)](
  496. x, y.contiguous(),
  497. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
  498. BC, BH, BW, C, H, W, NH, NW
  499. )
  500. return x
  501. @staticmethod
  502. def backward(ctx, x: torch.Tensor):
  503. in_channel_first = ctx.in_channel_first
  504. out_channel_first = ctx.out_channel_first
  505. one_by_one = ctx.one_by_one
  506. scans = ctx.scans
  507. B, C, H, W = ctx.shape
  508. BC, BH, BW, NC, NH, NW = ctx.triton_shape
  509. y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
  510. triton_cross_scan_flex[(NH * NW, NC, B)](
  511. x.contiguous(), y,
  512. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
  513. BC, BH, BW, C, H, W, NH, NW
  514. )
  515. return y, None, None, None, None, None
  516. # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  517. def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
  518. # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  519. # y: (B, 4, C, L) | (B, L, 4, C)
  520. # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
  521. CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
  522. if x.is_cuda:
  523. with torch.cuda.device(x.device):
  524. return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
  525. else:
  526. return CrossScanF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
  527. # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  528. def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
  529. # y: (B, 4, C, L) | (B, L, 4, C)
  530. # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
  531. # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
  532. CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
  533. if y.is_cuda:
  534. with torch.cuda.device(y.device):
  535. return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
  536. else:
  537. return CrossMergeF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
  538. # checks =================================================================
  539. # class CHECK:
  540. # def check_csm_triton():
  541. # B, C, H, W = 256, 192, 56, 57
  542. # dtype=torch.float16
  543. # dtype=torch.float32
  544. # x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
  545. # y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
  546. # x1 = x.clone().detach().requires_grad_(True)
  547. # y1 = y.clone().detach().requires_grad_(True)
  548. # def cross_scan(x: torch.Tensor):
  549. # B, C, H, W = x.shape
  550. # L = H * W
  551. # xs = torch.stack([
  552. # x.view(B, C, L),
  553. # torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
  554. # torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
  555. # torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
  556. # ], dim=1).view(B, 4, C, L)
  557. # return xs
  558. # def cross_merge(out_y: torch.Tensor):
  559. # B, K, D, H, W = out_y.shape
  560. # L = H * W
  561. # out_y = out_y.view(B, K, D, L)
  562. # inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
  563. # wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  564. # invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  565. # y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
  566. # return y
  567. # def cross_scan_1b1(x: torch.Tensor):
  568. # B, K, C, H, W = x.shape
  569. # L = H * W
  570. # xs = torch.stack([
  571. # x[:, 0].view(B, C, L),
  572. # torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
  573. # torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
  574. # torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
  575. # ], dim=1).view(B, 4, C, L)
  576. # return xs
  577. # def unidi_scan(x):
  578. # B, C, H, W = x.shape
  579. # x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
  580. # return x
  581. # def unidi_merge(ys):
  582. # B, K, C, H, W = ys.shape
  583. # return ys.view(B, 4, -1, H * W).sum(1)
  584. # def bidi_scan(x):
  585. # B, C, H, W = x.shape
  586. # x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
  587. # x = torch.cat([x, x.flip(dims=[-1])], dim=1)
  588. # return x
  589. # def bidi_merge(ys):
  590. # B, K, D, H, W = ys.shape
  591. # ys = ys.view(B, K, D, -1)
  592. # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  593. # return ys.contiguous().sum(1)
  594. # if True:
  595. # res0 = triton.testing.do_bench(lambda :cross_scan(x))
  596. # res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False))
  597. # # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
  598. # res3 = triton.testing.do_bench(lambda :cross_merge(y))
  599. # res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False))
  600. # # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
  601. # # print(res0, res1, res2, res3, res4, res5)
  602. # print(res0, res1, res3, res4)
  603. # res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
  604. # res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward())
  605. # # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
  606. # res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
  607. # res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward())
  608. # # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
  609. # # print(res0, res1, res2, res3, res4, res5)
  610. # print(res0, res1, res3, res4)
  611. # print("test cross scan")
  612. # for (cs0, cm0, cs1, cm1) in [
  613. # # channel_first -> channel_first
  614. # (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
  615. # (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
  616. # (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),
  617. # # flex: BLC->BCL; BCL->BLC; BLC->BLC;
  618. # (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
  619. # (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
  620. # (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
  621. # # previous
  622. # # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
  623. # # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
  624. # # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
  625. # ]:
  626. # x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
  627. # o0 = cs0(x)
  628. # o1 = cs1(x1)
  629. # o0.backward(y.view(B, 4, C, H * W))
  630. # o1.backward(y.view(B, 4, C, H * W))
  631. # print((o0 - o1).abs().max())
  632. # print((x.grad - x1.grad).abs().max())
  633. # o0 = cm0(y)
  634. # o1 = cm1(y1)
  635. # o0.backward(x.view(B, C, H * W))
  636. # o1.backward(x.view(B, C, H * W))
  637. # print((o0 - o1).abs().max())
  638. # print((y.grad - y1.grad).abs().max())
  639. # x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
  640. # print("===============", flush=True)
  641. # print("test cross scan one by one")
  642. # for (cs0, cs1) in [
  643. # (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
  644. # # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
  645. # ]:
  646. # o0 = cs0(y)
  647. # o1 = cs1(y1)
  648. # o0.backward(y.view(B, 4, C, H * W))
  649. # o1.backward(y.view(B, 4, C, H * W))
  650. # print((o0 - o1).abs().max())
  651. # print((y.grad - y1.grad).abs().max())
  652. # x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
  653. # print("===============", flush=True)
  654. # def check_csm_scan3():
  655. # if False:
  656. # x = torch.arange(0, 16).view(1, 1, 4, 4).cuda()
  657. # out1 = cross_scan_fn(x, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
  658. # out2 = cross_merge_fn(out1, scans=3, force_torch=True).view(1, 1, 4, 4)
  659. # out4 = cross_merge_fn(out1, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
  660. # out3 = cross_scan_fn(out4, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
  661. # out5 = cross_scan_fn(x.view(1, 4, 4, 1), in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
  662. # out6 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 1)
  663. # out8 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
  664. # out7 = cross_scan_fn(out8, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
  665. # print(out1.view(4, -1))
  666. # print(out2.view(-1))
  667. # print(out3.view(4, -1))
  668. # print(out4.view(4, -1))
  669. # print(out5.view(-1, 4).t())
  670. # print(out6.view(-1))
  671. # print(out7.view(-1, 4).t())
  672. # print(out8.view(-1, 4).t())
  673. # B, C, H, W = 27, 253, 57, 58
  674. # x = torch.randn((B, C, H, W)).cuda()
  675. # for scans in [0, 1, 2, 3]:
  676. # o1 = cross_scan_fn(x, scans=scans, force_torch=True).view(B, 4, C, H, W)
  677. # print((cross_scan_fn(x, scans=scans) == cross_scan_fn(x, scans=scans, force_torch=True)).all())
  678. # print((cross_merge_fn(o1, scans=scans) == cross_merge_fn(o1, scans=scans, force_torch=True)).all())
  679. # kwargs = dict(in_channel_first=False, out_channel_first=False)
  680. # x2 = x.permute(0, 2, 3, 1).contiguous()
  681. # o2 = o1.permute(0, 3, 4, 1, 2).contiguous()
  682. # print((cross_scan_fn(x, scans=scans, **kwargs) == cross_scan_fn(x, scans=scans, force_torch=True, **kwargs)).all())
  683. # print((cross_merge_fn(o2, scans=scans, **kwargs) == cross_merge_fn(o2, scans=scans, force_torch=True, **kwargs)).all())
  684. # breakpoint()
  685. # if __name__ == "__main__":
  686. # CHECK.check_csm_scan3()
  687. # CHECK.check_csm_triton()
  688. ##########################################################
  689. # csms6s.py
  690. ##########################################################
  691. import time
  692. import torch
  693. import warnings
  694. WITH_SELECTIVESCAN_MAMBA = True
  695. try:
  696. import selective_scan_cuda
  697. except ImportError:
  698. WITH_SELECTIVESCAN_MAMBA = False
  699. def selective_scan_torch(
  700. u: torch.Tensor, # (B, K * C, L)
  701. delta: torch.Tensor, # (B, K * C, L)
  702. A: torch.Tensor, # (K * C, N)
  703. B: torch.Tensor, # (B, K, N, L)
  704. C: torch.Tensor, # (B, K, N, L)
  705. D: torch.Tensor = None, # (K * C)
  706. delta_bias: torch.Tensor = None, # (K * C)
  707. delta_softplus=True,
  708. oflex=True,
  709. *args,
  710. **kwargs
  711. ):
  712. dtype_in = u.dtype
  713. Batch, K, N, L = B.shape
  714. KCdim = u.shape[1]
  715. Cdim = int(KCdim / K)
  716. assert u.shape == (Batch, KCdim, L)
  717. assert delta.shape == (Batch, KCdim, L)
  718. assert A.shape == (KCdim, N)
  719. assert C.shape == B.shape
  720. if delta_bias is not None:
  721. delta = delta + delta_bias[..., None]
  722. if delta_softplus:
  723. delta = torch.nn.functional.softplus(delta)
  724. u, delta, A, B, C = u.float(), delta.float(), A.float(), B.float(), C.float()
  725. B = B.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
  726. C = C.view(Batch, K, 1, N, L).repeat(1, 1, Cdim, 1, 1).view(Batch, KCdim, N, L)
  727. deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  728. deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
  729. if True:
  730. x = A.new_zeros((Batch, KCdim, N))
  731. ys = []
  732. for i in range(L):
  733. x = deltaA[:, :, i, :] * x + deltaB_u[:, :, i, :]
  734. y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
  735. ys.append(y)
  736. y = torch.stack(ys, dim=2) # (B, C, L)
  737. out = y if D is None else y + u * D.unsqueeze(-1)
  738. return out if oflex else out.to(dtype=dtype_in)
  739. class SelectiveScanCuda(torch.autograd.Function):
  740. @staticmethod
  741. @torch.cuda.amp.custom_fwd
  742. def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, oflex=True, backend=None):
  743. ctx.delta_softplus = delta_softplus
  744. # backend = "oflex" if WITH_SELECTIVESCAN_OFLEX and (backend is None) else backend
  745. # backend = "core" if WITH_SELECTIVESCAN_CORE and (backend is None) else backend
  746. backend = "mamba" if WITH_SELECTIVESCAN_MAMBA and (backend is None) else backend
  747. ctx.backend = backend
  748. if backend == "oflex":
  749. out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
  750. elif backend == "mamba":
  751. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
  752. ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
  753. return out
  754. @staticmethod
  755. @torch.cuda.amp.custom_bwd
  756. def backward(ctx, dout, *args):
  757. u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
  758. backend = ctx.backend
  759. if dout.stride(-1) != 1:
  760. dout = dout.contiguous()
  761. if backend == "oflex":
  762. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
  763. u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
  764. )
  765. elif backend == "mamba":
  766. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  767. u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
  768. False
  769. )
  770. return du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None
  771. def selective_scan_fn(
  772. u: torch.Tensor, # (B, K * C, L)
  773. delta: torch.Tensor, # (B, K * C, L)
  774. A: torch.Tensor, # (K * C, N)
  775. B: torch.Tensor, # (B, K, N, L)
  776. C: torch.Tensor, # (B, K, N, L)
  777. D: torch.Tensor = None, # (K * C)
  778. delta_bias: torch.Tensor = None, # (K * C)
  779. delta_softplus=True,
  780. oflex=True,
  781. backend=None,
  782. ):
  783. fn = selective_scan_torch if backend == "torch" or (not WITH_SELECTIVESCAN_MAMBA) else SelectiveScanCuda.apply
  784. return fn(u, delta, A, B, C, D, delta_bias, delta_softplus, oflex, backend)
  785. # fvcore flops =======================================
  786. def print_jit_input_names(inputs):
  787. print("input params: ", end=" ", flush=True)
  788. try:
  789. for i in range(10):
  790. print(inputs[i].debugName(), end=" ", flush=True)
  791. except Exception as e:
  792. pass
  793. print("", flush=True)
  794. def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
  795. """
  796. u: r(B D L)
  797. delta: r(B D L)
  798. A: r(D N)
  799. B: r(B N L)
  800. C: r(B N L)
  801. D: r(D)
  802. z: r(B D L)
  803. delta_bias: r(D), fp32
  804. ignores:
  805. [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
  806. """
  807. assert not with_complex
  808. # https://github.com/state-spaces/mamba/issues/110
  809. flops = 9 * B * L * D * N
  810. if with_D:
  811. flops += B * D * L
  812. if with_Z:
  813. flops += B * D * L
  814. return flops
  815. # this is only for selective_scan_ref...
  816. def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
  817. """
  818. u: r(B D L)
  819. delta: r(B D L)
  820. A: r(D N)
  821. B: r(B N L)
  822. C: r(B N L)
  823. D: r(D)
  824. z: r(B D L)
  825. delta_bias: r(D), fp32
  826. ignores:
  827. [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
  828. """
  829. import numpy as np
  830. # fvcore.nn.jit_handles
  831. def get_flops_einsum(input_shapes, equation):
  832. np_arrs = [np.zeros(s) for s in input_shapes]
  833. optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
  834. for line in optim.split("\n"):
  835. if "optimized flop" in line.lower():
  836. # divided by 2 because we count MAC (multiply-add counted as one flop)
  837. flop = float(np.floor(float(line.split(":")[-1]) / 2))
  838. return flop
  839. assert not with_complex
  840. flops = 0 # below code flops = 0
  841. flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
  842. if with_Group:
  843. flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
  844. else:
  845. flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
  846. in_for_flops = B * D * N
  847. if with_Group:
  848. in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
  849. else:
  850. in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
  851. flops += L * in_for_flops
  852. if with_D:
  853. flops += B * D * L
  854. if with_Z:
  855. flops += B * D * L
  856. return flops
  857. def selective_scan_flop_jit(inputs, outputs, backend="prefixsum", verbose=True):
  858. if verbose:
  859. print_jit_input_names(inputs)
  860. flops_fn = flops_selective_scan_ref if backend == "naive" else flops_selective_scan_fn
  861. B, D, L = inputs[0].type().sizes()
  862. N = inputs[2].type().sizes()[1]
  863. flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
  864. return flops
  865. # if __name__ == "__main__":
  866. # def params(B, K, C, N, L, device = torch.device("cuda"), itype = torch.float):
  867. # As = (-0.5 * torch.rand(K * C, N, device=device, dtype=torch.float32)).requires_grad_()
  868. # Bs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
  869. # Cs = torch.randn((B, K, N, L), device=device, dtype=itype).requires_grad_()
  870. # Ds = torch.randn((K * C), device=device, dtype=torch.float32).requires_grad_()
  871. # u = torch.randn((B, K * C, L), device=device, dtype=itype).requires_grad_()
  872. # delta = (0.5 * torch.rand((B, K * C, L), device=device, dtype=itype)).requires_grad_()
  873. # delta_bias = (0.5 * torch.rand((K * C), device=device, dtype=torch.float32)).requires_grad_()
  874. # return u, delta, As, Bs, Cs, Ds, delta_bias
  875. # def bench(func, xs, Warmup=30, NTimes=20):
  876. # import time
  877. # torch.cuda.synchronize()
  878. # for r in range(Warmup):
  879. # for x in xs:
  880. # func(x)
  881. # torch.cuda.synchronize()
  882. # tim0 = time.time()
  883. # for r in range(NTimes):
  884. # for x in xs:
  885. # func(x)
  886. # torch.cuda.synchronize()
  887. # return (time.time() - tim0) / NTimes
  888. # def check():
  889. # u, delta, As, Bs, Cs, Ds, delta_bias = params(1, 4, 16, 8, 512, itype=torch.float16)
  890. # u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1 = [x.clone().detach().requires_grad_() for x in [u, delta, As, Bs, Cs, Ds, delta_bias]]
  891. # # out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="torch")
  892. # out = selective_scan_fn(u1, delta1, As1, Bs1, Cs1, Ds1, delta_bias1, True, backend="oflex")
  893. # out_ref = selective_scan_fn(u, delta, As, Bs, Cs, Ds, delta_bias, True, backend="mamba")
  894. # print((out_ref - out).abs().max())
  895. # out.sum().backward()
  896. # out_ref.sum().backward()
  897. # for x, y in zip([u, As, Bs, Cs, Ds, delta, delta_bias], [u1, As1, Bs1, Cs1, Ds1, delta1, delta_bias1]):
  898. # print((x.grad - y.grad).abs().max())
  899. # u, delta, As, Bs, Cs, Ds, delta_bias = params(128, 4, 96, 8, 56 * 56)
  900. # print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="oflex"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
  901. # print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="mamba"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
  902. # print(bench(lambda x: selective_scan_fn(x[0], x[1], x[2], x[3], x[4], x[5], x[6], True, backend="torch"), [(u, delta, As, Bs, Cs, Ds, delta_bias),]))
  903. # check()
  904. ##########################################################
  905. # model.py
  906. ##########################################################
  907. import os
  908. import time
  909. import math
  910. import copy
  911. from functools import partial
  912. from typing import Optional, Callable, Any
  913. from collections import OrderedDict
  914. import torch
  915. import torch.nn as nn
  916. import torch.nn.functional as F
  917. import torch.utils.checkpoint as checkpoint
  918. from timm.models.layers import DropPath, trunc_normal_
  919. from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
  920. DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
  921. # =====================================================
  922. class Linear(nn.Linear):
  923. def __init__(self, *args, channel_first=False, groups=1, **kwargs):
  924. nn.Linear.__init__(self, *args, **kwargs)
  925. self.channel_first = channel_first
  926. self.groups = groups
  927. def forward(self, x: torch.Tensor):
  928. if self.channel_first:
  929. # B, C, H, W = x.shape
  930. if len(x.shape) == 4:
  931. return F.conv2d(x, self.weight[:, :, None, None], self.bias, groups=self.groups)
  932. elif len(x.shape) == 3:
  933. return F.conv1d(x, self.weight[:, :, None], self.bias, groups=self.groups)
  934. else:
  935. return F.linear(x, self.weight, self.bias)
  936. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  937. self_state_dict = self.state_dict()
  938. load_state_dict_keys = list(state_dict.keys())
  939. if prefix + "weight" in load_state_dict_keys:
  940. state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view_as(self_state_dict["weight"])
  941. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  942. class LayerNorm(nn.LayerNorm):
  943. def __init__(self, *args, channel_first=None, in_channel_first=False, out_channel_first=False, **kwargs):
  944. nn.LayerNorm.__init__(self, *args, **kwargs)
  945. if channel_first is not None:
  946. in_channel_first = channel_first
  947. out_channel_first = channel_first
  948. self.in_channel_first = in_channel_first
  949. self.out_channel_first = out_channel_first
  950. def forward(self, x: torch.Tensor):
  951. if self.in_channel_first:
  952. x = x.permute(0, 2, 3, 1)
  953. x = nn.LayerNorm.forward(self, x)
  954. if self.out_channel_first:
  955. x = x.permute(0, 3, 1, 2)
  956. return x
  957. class PatchMerge(nn.Module):
  958. def __init__(self, channel_first=True, in_channel_first=False, out_channel_first=False,):
  959. nn.Module.__init__(self)
  960. if channel_first is not None:
  961. in_channel_first = channel_first
  962. out_channel_first = channel_first
  963. self.in_channel_first = in_channel_first
  964. self.out_channel_first = out_channel_first
  965. # print(f"WARNING: output [(0, 0), (1, 0), (0, 1), (1, 1)] for (H, W).")
  966. def forward(self, x: torch.Tensor):
  967. B, C, H, W = x.shape
  968. if not self.in_channel_first:
  969. B, H, W, C = x.shape
  970. if (W % 2 != 0) or (H % 2 != 0):
  971. PH, PW = H - H % 2, W - W % 2
  972. pad_shape = (PW // 2, PW - PW // 2, PH // 2, PH - PH // 2)
  973. pad_shape = (*pad_shape, 0, 0, 0, 0) if self.in_channel_first else (0, 0, *pad_shape, 0, 0)
  974. x = nn.functional.pad(x, pad_shape)
  975. xs = [
  976. x[..., 0::2, 0::2], x[..., 1::2, 0::2],
  977. x[..., 0::2, 1::2], x[..., 1::2, 1::2],
  978. ] if self.in_channel_first else [
  979. x[..., 0::2, 0::2, :], x[..., 1::2, 0::2, :],
  980. x[..., 0::2, 1::2, :], x[..., 1::2, 1::2, :],
  981. ]
  982. xs = torch.cat(xs, (1 if self.out_channel_first else -1))
  983. return xs
  984. class Permute(nn.Module):
  985. def __init__(self, *args):
  986. super().__init__()
  987. self.args = args
  988. def forward(self, x: torch.Tensor):
  989. return x.permute(*self.args)
  990. class SoftmaxSpatial(nn.Softmax):
  991. def forward(self, x: torch.Tensor):
  992. if self.dim == -1:
  993. B, C, H, W = x.shape
  994. return super().forward(x.view(B, C, -1)).view(B, C, H, W)
  995. elif self.dim == 1:
  996. B, H, W, C = x.shape
  997. return super().forward(x.view(B, -1, C)).view(B, H, W, C)
  998. else:
  999. raise NotImplementedError
  1000. class Mlp(nn.Module):
  1001. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channel_first=False):
  1002. super().__init__()
  1003. out_features = out_features or in_features
  1004. hidden_features = hidden_features or in_features
  1005. self.fc1 = Linear(in_features, hidden_features, channel_first=channel_first)
  1006. self.act = act_layer()
  1007. self.fc2 = Linear(hidden_features, out_features, channel_first=channel_first)
  1008. self.drop = nn.Dropout(drop)
  1009. def forward(self, x):
  1010. x = self.fc1(x)
  1011. x = self.act(x)
  1012. x = self.drop(x)
  1013. x = self.fc2(x)
  1014. x = self.drop(x)
  1015. return x
  1016. class mamba_init:
  1017. @staticmethod
  1018. def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4):
  1019. dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
  1020. # Initialize special dt projection to preserve variance at initialization
  1021. dt_init_std = dt_rank**-0.5 * dt_scale
  1022. if dt_init == "constant":
  1023. nn.init.constant_(dt_proj.weight, dt_init_std)
  1024. elif dt_init == "random":
  1025. nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
  1026. else:
  1027. raise NotImplementedError
  1028. # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
  1029. dt = torch.exp(
  1030. torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min))
  1031. + math.log(dt_min)
  1032. ).clamp(min=dt_init_floor)
  1033. # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  1034. inv_dt = dt + torch.log(-torch.expm1(-dt))
  1035. with torch.no_grad():
  1036. dt_proj.bias.copy_(inv_dt)
  1037. # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
  1038. # dt_proj.bias._no_reinit = True
  1039. return dt_proj
  1040. @staticmethod
  1041. def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
  1042. # S4D real initialization
  1043. A = torch.arange(1, d_state + 1, dtype=torch.float32, device=device).view(1, -1).repeat(d_inner, 1).contiguous()
  1044. A_log = torch.log(A) # Keep A_log in fp32
  1045. if copies > 0:
  1046. A_log = A_log[None].repeat(copies, 1, 1).contiguous()
  1047. if merge:
  1048. A_log = A_log.flatten(0, 1)
  1049. A_log = nn.Parameter(A_log)
  1050. A_log._no_weight_decay = True
  1051. return A_log
  1052. @staticmethod
  1053. def D_init(d_inner, copies=-1, device=None, merge=True):
  1054. # D "skip" parameter
  1055. D = torch.ones(d_inner, device=device)
  1056. if copies > 0:
  1057. D = D[None].repeat(copies, 1).contiguous()
  1058. if merge:
  1059. D = D.flatten(0, 1)
  1060. D = nn.Parameter(D) # Keep in fp32
  1061. D._no_weight_decay = True
  1062. return D
  1063. @classmethod
  1064. def init_dt_A_D(cls, d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4):
  1065. # dt proj ============================
  1066. dt_projs = [
  1067. cls.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor)
  1068. for _ in range(k_group)
  1069. ]
  1070. dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in dt_projs], dim=0)) # (K, inner, rank)
  1071. dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in dt_projs], dim=0)) # (K, inner)
  1072. del dt_projs
  1073. # A, D =======================================
  1074. A_logs = cls.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
  1075. Ds = cls.D_init(d_inner, copies=k_group, merge=True) # (K * D)
  1076. return A_logs, Ds, dt_projs_weight, dt_projs_bias
  1077. # support: v0, v0seq
  1078. class SS2Dv0:
  1079. def __initv0__(
  1080. self,
  1081. # basic dims ===========
  1082. d_model=96,
  1083. d_state=16,
  1084. ssm_ratio=2.0,
  1085. dt_rank="auto",
  1086. # ======================
  1087. dropout=0.0,
  1088. # ======================
  1089. seq=False,
  1090. force_fp32=True,
  1091. **kwargs,
  1092. ):
  1093. if "channel_first" in kwargs:
  1094. assert not kwargs["channel_first"]
  1095. act_layer = nn.SiLU
  1096. dt_min = 0.001
  1097. dt_max = 0.1
  1098. dt_init = "random"
  1099. dt_scale = 1.0
  1100. dt_init_floor = 1e-4
  1101. bias = False
  1102. conv_bias = True
  1103. d_conv = 3
  1104. k_group = 4
  1105. factory_kwargs = {"device": None, "dtype": None}
  1106. super().__init__()
  1107. d_inner = int(ssm_ratio * d_model)
  1108. dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
  1109. self.forward = self.forwardv0
  1110. if seq:
  1111. self.forward = partial(self.forwardv0, seq=True)
  1112. if not force_fp32:
  1113. self.forward = partial(self.forwardv0, force_fp32=False)
  1114. # in proj ============================
  1115. self.in_proj = nn.Linear(d_model, d_inner * 2, bias=bias)
  1116. self.act: nn.Module = act_layer()
  1117. self.conv2d = nn.Conv2d(
  1118. in_channels=d_inner,
  1119. out_channels=d_inner,
  1120. groups=d_inner,
  1121. bias=conv_bias,
  1122. kernel_size=d_conv,
  1123. padding=(d_conv - 1) // 2,
  1124. **factory_kwargs,
  1125. )
  1126. # x proj ============================
  1127. self.x_proj = [
  1128. nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False)
  1129. for _ in range(k_group)
  1130. ]
  1131. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  1132. del self.x_proj
  1133. # dt proj, A, D ============================
  1134. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  1135. d_state, dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=4,
  1136. )
  1137. # out proj =======================================
  1138. self.out_norm = nn.LayerNorm(d_inner)
  1139. self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
  1140. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  1141. def forwardv0(self, x: torch.Tensor, seq=False, force_fp32=True, **kwargs):
  1142. x = self.in_proj(x)
  1143. x, z = x.chunk(2, dim=-1) # (b, h, w, d)
  1144. z = self.act(z)
  1145. x = x.permute(0, 3, 1, 2).contiguous()
  1146. x = self.conv2d(x) # (b, d, h, w)
  1147. x = self.act(x)
  1148. selective_scan = partial(selective_scan_fn, backend="mamba")
  1149. B, D, H, W = x.shape
  1150. D, N = self.A_logs.shape
  1151. K, D, R = self.dt_projs_weight.shape
  1152. L = H * W
  1153. x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
  1154. xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
  1155. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
  1156. if hasattr(self, "x_proj_bias"):
  1157. x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
  1158. dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
  1159. dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
  1160. xs = xs.view(B, -1, L) # (b, k * d, l)
  1161. dts = dts.contiguous().view(B, -1, L) # (b, k * d, l)
  1162. Bs = Bs.contiguous() # (b, k, d_state, l)
  1163. Cs = Cs.contiguous() # (b, k, d_state, l)
  1164. As = -self.A_logs.float().exp() # (k * d, d_state)
  1165. Ds = self.Ds.float() # (k * d)
  1166. dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
  1167. # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
  1168. # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
  1169. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  1170. if force_fp32:
  1171. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  1172. if seq:
  1173. out_y = []
  1174. for i in range(4):
  1175. yi = selective_scan(
  1176. xs.view(B, K, -1, L)[:, i], dts.view(B, K, -1, L)[:, i],
  1177. As.view(K, -1, N)[i], Bs[:, i].unsqueeze(1), Cs[:, i].unsqueeze(1), Ds.view(K, -1)[i],
  1178. delta_bias=dt_projs_bias.view(K, -1)[i],
  1179. delta_softplus=True,
  1180. ).view(B, -1, L)
  1181. out_y.append(yi)
  1182. out_y = torch.stack(out_y, dim=1)
  1183. else:
  1184. out_y = selective_scan(
  1185. xs, dts,
  1186. As, Bs, Cs, Ds,
  1187. delta_bias=dt_projs_bias,
  1188. delta_softplus=True,
  1189. ).view(B, K, -1, L)
  1190. assert out_y.dtype == torch.float
  1191. inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
  1192. wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  1193. invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  1194. y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
  1195. y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
  1196. y = self.out_norm(y).view(B, H, W, -1)
  1197. y = y * z
  1198. out = self.dropout(self.out_proj(y))
  1199. return out
  1200. # support: v01-v05; v051d,v052d,v052dc;
  1201. # postfix: _onsigmoid,_onsoftmax,_ondwconv3,_onnone;_nozact,_noz;_oact;_no32;
  1202. # history support: v2,v3;v31d,v32d,v32dc;
  1203. class SS2Dv2:
  1204. def __initv2__(
  1205. self,
  1206. # basic dims ===========
  1207. d_model=96,
  1208. d_state=16,
  1209. ssm_ratio=2.0,
  1210. dt_rank="auto",
  1211. act_layer=nn.SiLU,
  1212. # dwconv ===============
  1213. d_conv=3, # < 2 means no conv
  1214. conv_bias=True,
  1215. # ======================
  1216. dropout=0.0,
  1217. bias=False,
  1218. # dt init ==============
  1219. dt_min=0.001,
  1220. dt_max=0.1,
  1221. dt_init="random",
  1222. dt_scale=1.0,
  1223. dt_init_floor=1e-4,
  1224. initialize="v0",
  1225. # ======================
  1226. forward_type="v2",
  1227. channel_first=False,
  1228. # ======================
  1229. **kwargs,
  1230. ):
  1231. factory_kwargs = {"device": None, "dtype": None}
  1232. super().__init__()
  1233. self.k_group = 4
  1234. self.d_model = int(d_model)
  1235. self.d_state = int(d_state)
  1236. self.d_inner = int(ssm_ratio * d_model)
  1237. self.dt_rank = int(math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank)
  1238. self.channel_first = channel_first
  1239. self.with_dconv = d_conv > 1
  1240. self.forward = self.forwardv2
  1241. # tags for forward_type ==============================
  1242. checkpostfix = self.checkpostfix
  1243. self.disable_force32, forward_type = checkpostfix("_no32", forward_type)
  1244. self.oact, forward_type = checkpostfix("_oact", forward_type)
  1245. self.disable_z, forward_type = checkpostfix("_noz", forward_type)
  1246. self.disable_z_act, forward_type = checkpostfix("_nozact", forward_type)
  1247. self.out_norm, forward_type = self.get_outnorm(forward_type, self.d_inner, channel_first)
  1248. # forward_type debug =======================================
  1249. FORWARD_TYPES = dict(
  1250. v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba", scan_force_torch=True),
  1251. v02=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="mamba"),
  1252. v03=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="oflex"),
  1253. v04=partial(self.forward_corev2, force_fp32=False), # selective_scan_backend="oflex", scan_mode="cross2d"
  1254. v05=partial(self.forward_corev2, force_fp32=False, no_einsum=True), # selective_scan_backend="oflex", scan_mode="cross2d"
  1255. # ===============================
  1256. v051d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="unidi"),
  1257. v052d=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="bidi"),
  1258. v052dc=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode="cascade2d"),
  1259. v052d3=partial(self.forward_corev2, force_fp32=False, no_einsum=True, scan_mode=3), # debug
  1260. # ===============================
  1261. v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), selective_scan_backend="core"),
  1262. v3=partial(self.forward_corev2, force_fp32=False, selective_scan_backend="oflex"),
  1263. )
  1264. self.forward_core = FORWARD_TYPES.get(forward_type, None)
  1265. # in proj =======================================
  1266. d_proj = self.d_inner if self.disable_z else (self.d_inner * 2)
  1267. self.in_proj = Linear(self.d_model, d_proj, bias=bias, channel_first=channel_first)
  1268. self.act: nn.Module = act_layer()
  1269. # conv =======================================
  1270. if self.with_dconv:
  1271. self.conv2d = nn.Conv2d(
  1272. in_channels=self.d_inner,
  1273. out_channels=self.d_inner,
  1274. groups=self.d_inner,
  1275. bias=conv_bias,
  1276. kernel_size=d_conv,
  1277. padding=(d_conv - 1) // 2,
  1278. **factory_kwargs,
  1279. )
  1280. # x proj ============================
  1281. self.x_proj = Linear(self.d_inner, self.k_group * (self.dt_rank + self.d_state * 2), groups=self.k_group, bias=False, channel_first=True)
  1282. self.dt_projs = Linear(self.dt_rank, self.k_group * self.d_inner, groups=self.k_group, bias=False, channel_first=True)
  1283. # self.x_proj = [
  1284. # nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False)
  1285. # for _ in range(self.k_group)
  1286. # ]
  1287. # self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
  1288. # del self.x_proj
  1289. # out proj =======================================
  1290. self.out_act = nn.GELU() if self.oact else nn.Identity()
  1291. self.out_proj = Linear(self.d_inner, self.d_model, bias=bias, channel_first=channel_first)
  1292. self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
  1293. if initialize in ["v0"]:
  1294. self.A_logs, self.Ds, self.dt_projs_weight, self.dt_projs_bias = mamba_init.init_dt_A_D(
  1295. self.d_state, self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, k_group=self.k_group,
  1296. )
  1297. elif initialize in ["v1"]:
  1298. # simple init dt_projs, A_logs, Ds
  1299. self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
  1300. self.A_logs = nn.Parameter(torch.randn((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  1301. self.dt_projs_weight = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner, self.dt_rank))) # 0.1 is added in 0430
  1302. self.dt_projs_bias = nn.Parameter(0.1 * torch.randn((self.k_group, self.d_inner))) # 0.1 is added in 0430
  1303. elif initialize in ["v2"]:
  1304. # simple init dt_projs, A_logs, Ds
  1305. self.Ds = nn.Parameter(torch.ones((self.k_group * self.d_inner)))
  1306. self.A_logs = nn.Parameter(torch.zeros((self.k_group * self.d_inner, self.d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
  1307. self.dt_projs_weight = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner, self.dt_rank)))
  1308. self.dt_projs_bias = nn.Parameter(0.1 * torch.rand((self.k_group, self.d_inner)))
  1309. self.dt_projs.weight.data = self.dt_projs_weight.data.view(self.dt_projs.weight.shape)
  1310. # self.dt_projs.bias.data = self.dt_projs_bias.data.view(self.dt_projs.bias.shape)
  1311. del self.dt_projs_weight
  1312. # del self.dt_projs_bias
  1313. def forward_corev2(
  1314. self,
  1315. x: torch.Tensor=None,
  1316. # ==============================
  1317. force_fp32=False, # True: input fp32
  1318. # ==============================
  1319. ssoflex=True, # True: input 16 or 32 output 32 False: output dtype as input
  1320. # ==============================
  1321. selective_scan_backend = None,
  1322. # ==============================
  1323. scan_mode = "cross2d",
  1324. scan_force_torch = False,
  1325. # ==============================
  1326. **kwargs,
  1327. ):
  1328. assert selective_scan_backend in [None, "oflex", "mamba", "torch"]
  1329. _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode # for debug
  1330. assert isinstance(_scan_mode, int)
  1331. delta_softplus = True
  1332. channel_first = self.channel_first
  1333. to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)
  1334. force_fp32 = force_fp32 or ((not ssoflex) and self.training)
  1335. B, D, H, W = x.shape
  1336. N = self.d_state
  1337. K, D, R = self.k_group, self.d_inner, self.dt_rank
  1338. L = H * W
  1339. def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
  1340. return selective_scan_fn(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, backend=selective_scan_backend)
  1341. if True:
  1342. xs = cross_scan_fn(x, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  1343. x_dbl = self.x_proj(xs.view(B, -1, L))
  1344. dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
  1345. dts = dts.contiguous().view(B, -1, L)
  1346. dts = self.dt_projs(dts)
  1347. xs = xs.view(B, -1, L)
  1348. dts = dts.contiguous().view(B, -1, L)
  1349. As = -self.A_logs.to(torch.float).exp() # (k * c, d_state)
  1350. Ds = self.Ds.to(torch.float) # (K * c)
  1351. Bs = Bs.contiguous().view(B, K, N, L)
  1352. Cs = Cs.contiguous().view(B, K, N, L)
  1353. delta_bias = self.dt_projs_bias.view(-1).to(torch.float)
  1354. if force_fp32:
  1355. xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)
  1356. ys: torch.Tensor = selective_scan(
  1357. xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
  1358. ).view(B, K, -1, H, W)
  1359. y: torch.Tensor = cross_merge_fn(ys, in_channel_first=True, out_channel_first=True, scans=_scan_mode, force_torch=scan_force_torch)
  1360. if getattr(self, "__DEBUG__", False):
  1361. setattr(self, "__data__", dict(
  1362. A_logs=self.A_logs, Bs=Bs, Cs=Cs, Ds=Ds,
  1363. us=xs, dts=dts, delta_bias=delta_bias,
  1364. ys=ys, y=y, H=H, W=W,
  1365. ))
  1366. y = y.view(B, -1, H, W)
  1367. if not channel_first:
  1368. y = y.permute(0, 2, 3, 1).contiguous()
  1369. y = self.out_norm(y)
  1370. return y.to(x.dtype)
  1371. def forwardv2(self, x: torch.Tensor, **kwargs):
  1372. x = self.in_proj(x)
  1373. if not self.disable_z:
  1374. x, z = x.chunk(2, dim=(1 if self.channel_first else -1)) # (b, h, w, d)
  1375. if not self.disable_z_act:
  1376. z = self.act(z)
  1377. if not self.channel_first:
  1378. x = x.permute(0, 3, 1, 2).contiguous()
  1379. if self.with_dconv:
  1380. x = self.conv2d(x) # (b, d, h, w)
  1381. x = self.act(x)
  1382. y = self.forward_core(x)
  1383. y = self.out_act(y)
  1384. if not self.disable_z:
  1385. y = y * z
  1386. out = self.dropout(self.out_proj(y))
  1387. return out
  1388. @staticmethod
  1389. def get_outnorm(forward_type="", d_inner=192, channel_first=True):
  1390. def checkpostfix(tag, value):
  1391. ret = value[-len(tag):] == tag
  1392. if ret:
  1393. value = value[:-len(tag)]
  1394. return ret, value
  1395. out_norm_none, forward_type = checkpostfix("_onnone", forward_type)
  1396. out_norm_dwconv3, forward_type = checkpostfix("_ondwconv3", forward_type)
  1397. out_norm_cnorm, forward_type = checkpostfix("_oncnorm", forward_type)
  1398. out_norm_softmax, forward_type = checkpostfix("_onsoftmax", forward_type)
  1399. out_norm_sigmoid, forward_type = checkpostfix("_onsigmoid", forward_type)
  1400. out_norm = nn.Identity()
  1401. if out_norm_none:
  1402. out_norm = nn.Identity()
  1403. elif out_norm_cnorm:
  1404. out_norm = nn.Sequential(
  1405. LayerNorm(d_inner, channel_first=channel_first),
  1406. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1407. nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
  1408. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1409. )
  1410. elif out_norm_dwconv3:
  1411. out_norm = nn.Sequential(
  1412. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1413. nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False),
  1414. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1415. )
  1416. elif out_norm_softmax:
  1417. out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1))
  1418. elif out_norm_sigmoid:
  1419. out_norm = nn.Sigmoid()
  1420. else:
  1421. out_norm = LayerNorm(d_inner, channel_first=channel_first)
  1422. return out_norm, forward_type
  1423. @staticmethod
  1424. def checkpostfix(tag, value):
  1425. ret = value[-len(tag):] == tag
  1426. if ret:
  1427. value = value[:-len(tag)]
  1428. return ret, value
  1429. class SS2D(nn.Module, SS2Dv0, SS2Dv2):
  1430. def __init__(
  1431. self,
  1432. # basic dims ===========
  1433. d_model=96,
  1434. d_state=16,
  1435. ssm_ratio=2.0,
  1436. dt_rank="auto",
  1437. act_layer=nn.SiLU,
  1438. # dwconv ===============
  1439. d_conv=3, # < 2 means no conv
  1440. conv_bias=True,
  1441. # ======================
  1442. dropout=0.0,
  1443. bias=False,
  1444. # dt init ==============
  1445. dt_min=0.001,
  1446. dt_max=0.1,
  1447. dt_init="random",
  1448. dt_scale=1.0,
  1449. dt_init_floor=1e-4,
  1450. initialize="v0",
  1451. # ======================
  1452. forward_type="v2",
  1453. channel_first=False,
  1454. # ======================
  1455. **kwargs,
  1456. ):
  1457. nn.Module.__init__(self)
  1458. kwargs.update(
  1459. d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank,
  1460. act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias,
  1461. dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor,
  1462. initialize=initialize, forward_type=forward_type, channel_first=channel_first,
  1463. )
  1464. if forward_type in ["v0", "v0seq"]:
  1465. self.__initv0__(seq=("seq" in forward_type), **kwargs)
  1466. elif forward_type.startswith("xv"):
  1467. self.__initxv__(**kwargs)
  1468. elif forward_type.startswith("m"):
  1469. self.__initm0__(**kwargs)
  1470. else:
  1471. self.__initv2__(**kwargs)
  1472. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  1473. self_state_dict = self.state_dict()
  1474. self_state_dict_keys = list(self.state_dict().keys())
  1475. load_state_dict_keys = list(state_dict.keys())
  1476. names = {
  1477. "x_proj_weight": "x_proj.weight",
  1478. "x_proj_bias": "x_proj.bias",
  1479. "dt_projs_weight": "dt_projs.weight",
  1480. "dt_projs_bias": "dt_projs.bias",
  1481. }
  1482. for k, v in names.items():
  1483. if (prefix + k in load_state_dict_keys) and (k not in self_state_dict_keys):
  1484. assert v in self_state_dict_keys, f"{v} not in state_dict."
  1485. state_dict[prefix + v] = state_dict[prefix + k].view_as(self_state_dict[v])
  1486. state_dict.pop(prefix + k)
  1487. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  1488. # =====================================================
  1489. class VSSBlock(nn.Module):
  1490. def __init__(
  1491. self,
  1492. hidden_dim: int = 0,
  1493. drop_path: float = 0,
  1494. channel_first=False,
  1495. # =============================
  1496. ssm_d_state: int = 16,
  1497. ssm_ratio=2.0,
  1498. ssm_dt_rank: Any = "auto",
  1499. ssm_act_layer=nn.SiLU,
  1500. ssm_conv: int = 3,
  1501. ssm_conv_bias=True,
  1502. ssm_drop_rate: float = 0,
  1503. ssm_init="v0",
  1504. forward_type="v2",
  1505. # =============================
  1506. mlp_ratio=4.0,
  1507. mlp_act_layer=nn.GELU,
  1508. mlp_drop_rate: float = 0.0,
  1509. # =============================
  1510. use_checkpoint: bool = False,
  1511. post_norm: bool = False,
  1512. # =============================
  1513. **kwargs,
  1514. ):
  1515. super().__init__()
  1516. self.ssm_branch = ssm_ratio > 0
  1517. self.mlp_branch = mlp_ratio > 0
  1518. self.use_checkpoint = use_checkpoint
  1519. self.post_norm = post_norm
  1520. if self.ssm_branch:
  1521. self.norm = LayerNorm(hidden_dim, channel_first=channel_first)
  1522. self.op = SS2D(
  1523. d_model=hidden_dim,
  1524. d_state=ssm_d_state,
  1525. ssm_ratio=ssm_ratio,
  1526. dt_rank=ssm_dt_rank,
  1527. act_layer=ssm_act_layer,
  1528. # ==========================
  1529. d_conv=ssm_conv,
  1530. conv_bias=ssm_conv_bias,
  1531. # ==========================
  1532. dropout=ssm_drop_rate,
  1533. # bias=False,
  1534. # ==========================
  1535. # dt_min=0.001,
  1536. # dt_max=0.1,
  1537. # dt_init="random",
  1538. # dt_scale="random",
  1539. # dt_init_floor=1e-4,
  1540. initialize=ssm_init,
  1541. # ==========================
  1542. forward_type=forward_type,
  1543. channel_first=channel_first,
  1544. )
  1545. self.drop_path = DropPath(drop_path)
  1546. if self.mlp_branch:
  1547. self.norm2 = LayerNorm(hidden_dim, channel_first=channel_first)
  1548. mlp_hidden_dim = int(hidden_dim * mlp_ratio)
  1549. self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channel_first=channel_first)
  1550. def _forward(self, input: torch.Tensor):
  1551. x = input
  1552. if self.ssm_branch:
  1553. if self.post_norm:
  1554. x = x + self.drop_path(self.norm(self.op(x)))
  1555. else:
  1556. x = x + self.drop_path(self.op(self.norm(x)))
  1557. if self.mlp_branch:
  1558. if self.post_norm:
  1559. x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
  1560. else:
  1561. x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
  1562. return x
  1563. def forward(self, input: torch.Tensor):
  1564. if self.use_checkpoint:
  1565. return checkpoint.checkpoint(self._forward, input)
  1566. else:
  1567. return self._forward(input)
  1568. class VSSM(nn.Module):
  1569. def __init__(
  1570. self,
  1571. patch_size=4,
  1572. in_chans=3,
  1573. num_classes=1000,
  1574. depths=[2, 2, 9, 2],
  1575. dims=[96, 192, 384, 768],
  1576. # =========================
  1577. ssm_d_state=16,
  1578. ssm_ratio=2.0,
  1579. ssm_dt_rank="auto",
  1580. ssm_act_layer="silu",
  1581. ssm_conv=3,
  1582. ssm_conv_bias=True,
  1583. ssm_drop_rate=0.0,
  1584. ssm_init="v0",
  1585. forward_type="v2",
  1586. # =========================
  1587. mlp_ratio=4.0,
  1588. mlp_act_layer="gelu",
  1589. mlp_drop_rate=0.0,
  1590. gmlp=False,
  1591. # =========================
  1592. drop_path_rate=0.1,
  1593. patch_norm=True,
  1594. norm_layer="LN", # "BN", "LN2D"
  1595. downsample_version: str = "v2", # "v1", "v2", "v3"
  1596. patchembed_version: str = "v1", # "v1", "v2"
  1597. use_checkpoint=False,
  1598. # =========================
  1599. posembed=False,
  1600. imgsize=224,
  1601. _SS2D=SS2D,
  1602. # =========================
  1603. **kwargs,
  1604. ):
  1605. super().__init__()
  1606. self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
  1607. self.num_classes = num_classes
  1608. self.num_layers = len(depths)
  1609. if isinstance(dims, int):
  1610. dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
  1611. self.num_features = dims[-1]
  1612. self.dims = dims
  1613. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  1614. _ACTLAYERS = dict(
  1615. silu=nn.SiLU,
  1616. gelu=nn.GELU,
  1617. relu=nn.ReLU,
  1618. sigmoid=nn.Sigmoid,
  1619. )
  1620. ssm_act_layer: nn.Module = _ACTLAYERS.get(ssm_act_layer.lower(), None)
  1621. mlp_act_layer: nn.Module = _ACTLAYERS.get(mlp_act_layer.lower(), None)
  1622. self.pos_embed = self._pos_embed(dims[0], patch_size, imgsize) if posembed else None
  1623. self.patch_embed = self._make_patch_embed(in_chans, dims[0], patch_size, patch_norm, channel_first=self.channel_first, version=patchembed_version)
  1624. self.layers = nn.ModuleList()
  1625. for i_layer in range(self.num_layers):
  1626. downsample = self._make_downsample(
  1627. self.dims[i_layer],
  1628. self.dims[i_layer + 1],
  1629. channel_first=self.channel_first,
  1630. version=downsample_version,
  1631. ) if (i_layer < self.num_layers - 1) else nn.Identity()
  1632. self.layers.append(self._make_layer(
  1633. dim = self.dims[i_layer],
  1634. drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
  1635. use_checkpoint=use_checkpoint,
  1636. downsample=downsample,
  1637. channel_first=self.channel_first,
  1638. # =================
  1639. ssm_d_state=ssm_d_state,
  1640. ssm_ratio=ssm_ratio,
  1641. ssm_dt_rank=ssm_dt_rank,
  1642. ssm_act_layer=ssm_act_layer,
  1643. ssm_conv=ssm_conv,
  1644. ssm_conv_bias=ssm_conv_bias,
  1645. ssm_drop_rate=ssm_drop_rate,
  1646. ssm_init=ssm_init,
  1647. forward_type=forward_type,
  1648. # =================
  1649. mlp_ratio=mlp_ratio,
  1650. mlp_act_layer=mlp_act_layer,
  1651. mlp_drop_rate=mlp_drop_rate,
  1652. gmlp=gmlp,
  1653. # =================
  1654. _SS2D=_SS2D,
  1655. ))
  1656. self.classifier = nn.Sequential(OrderedDict(
  1657. norm=LayerNorm(self.num_features, channel_first=self.channel_first), # B,H,W,C
  1658. permute=(Permute(0, 3, 1, 2) if not self.channel_first else nn.Identity()),
  1659. avgpool=nn.AdaptiveAvgPool2d(1),
  1660. flatten=nn.Flatten(1),
  1661. head=nn.Linear(self.num_features, num_classes),
  1662. ))
  1663. self.apply(self._init_weights)
  1664. @staticmethod
  1665. def _pos_embed(embed_dims, patch_size, img_size):
  1666. patch_height, patch_width = (img_size // patch_size, img_size // patch_size)
  1667. pos_embed = nn.Parameter(torch.zeros(1, embed_dims, patch_height, patch_width))
  1668. trunc_normal_(pos_embed, std=0.02)
  1669. return pos_embed
  1670. def _init_weights(self, m: nn.Module):
  1671. if isinstance(m, nn.Linear):
  1672. trunc_normal_(m.weight, std=.02)
  1673. if isinstance(m, nn.Linear) and m.bias is not None:
  1674. nn.init.constant_(m.bias, 0)
  1675. elif isinstance(m, nn.LayerNorm):
  1676. nn.init.constant_(m.bias, 0)
  1677. nn.init.constant_(m.weight, 1.0)
  1678. # used in building optimizer
  1679. @torch.jit.ignore
  1680. def no_weight_decay(self):
  1681. return {"pos_embed"}
  1682. # used in building optimizer
  1683. @torch.jit.ignore
  1684. def no_weight_decay_keywords(self):
  1685. return {}
  1686. @staticmethod
  1687. def _make_patch_embed(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, channel_first=False, version="v1"):
  1688. # if channel first, then Norm and Output are both channel_first
  1689. if version == "v1": # simple patch_embed, same with swin transformer
  1690. return nn.Sequential(
  1691. nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True),
  1692. nn.Identity(),
  1693. (LayerNorm(embed_dim, in_channel_first=True, out_channel_first=channel_first)
  1694. if patch_norm else (nn.Identity() if channel_first else Permute(0, 2, 3, 1))),
  1695. )
  1696. elif version == "v2": # patch embed with stacked conv2d
  1697. stride = patch_size // 2
  1698. kernel_size = stride + 1
  1699. padding = 1
  1700. return nn.Sequential(
  1701. nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding),
  1702. nn.Identity(),
  1703. (LayerNorm(embed_dim // 2, channel_first=True) if patch_norm else nn.Identity()),
  1704. nn.Identity(),
  1705. nn.GELU(),
  1706. nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
  1707. nn.Identity(),
  1708. (LayerNorm(embed_dim, in_channel_first=True, out_channel_first=channel_first)
  1709. if patch_norm else (nn.Identity() if channel_first else Permute(0, 2, 3, 1))),
  1710. )
  1711. raise NotImplementedError
  1712. @staticmethod
  1713. def _make_downsample(dim=96, out_dim=192, norm=True, channel_first=False, version="v1"):
  1714. # if channel first, then Norm and Output are both channel_first
  1715. if version == "v1": # patch merging from swin transformer
  1716. # return PatchMerging2D(dim, 2 * dim, norm_layer, False)
  1717. return nn.Sequential(
  1718. PatchMerge(channel_first),
  1719. LayerNorm(4 * dim, channel_first=channel_first) if norm else nn.Identity(),
  1720. Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False, channel_first=channel_first),
  1721. )
  1722. elif version == "v2": # combine pixelunshuffle and linear into conv2d
  1723. return nn.Sequential(
  1724. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1725. nn.Conv2d(dim, out_dim, kernel_size=2, stride=2),
  1726. nn.Identity(),
  1727. LayerNorm(out_dim, in_channel_first=True, out_channel_first=channel_first) if norm else
  1728. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1729. )
  1730. elif version == "v3": # conv2d with overlap
  1731. return nn.Sequential(
  1732. (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
  1733. nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
  1734. nn.Identity(),
  1735. LayerNorm(out_dim, in_channel_first=True, out_channel_first=channel_first) if norm else
  1736. (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
  1737. )
  1738. raise NotImplementedError
  1739. @staticmethod
  1740. def _make_layer(
  1741. dim=96,
  1742. drop_path=[0.1, 0.1],
  1743. use_checkpoint=False,
  1744. downsample=nn.Identity(),
  1745. channel_first=False,
  1746. # ===========================
  1747. ssm_d_state=16,
  1748. ssm_ratio=2.0,
  1749. ssm_dt_rank="auto",
  1750. ssm_act_layer=nn.SiLU,
  1751. ssm_conv=3,
  1752. ssm_conv_bias=True,
  1753. ssm_drop_rate=0.0,
  1754. ssm_init="v0",
  1755. forward_type="v2",
  1756. # ===========================
  1757. mlp_ratio=4.0,
  1758. mlp_act_layer=nn.GELU,
  1759. mlp_drop_rate=0.0,
  1760. # ===========================
  1761. **kwargs,
  1762. ):
  1763. # if channel first, then Norm and Output are both channel_first
  1764. depth = len(drop_path)
  1765. blocks = []
  1766. for d in range(depth):
  1767. blocks.append(VSSBlock(
  1768. hidden_dim=dim,
  1769. drop_path=drop_path[d],
  1770. channel_first=channel_first,
  1771. ssm_d_state=ssm_d_state,
  1772. ssm_ratio=ssm_ratio,
  1773. ssm_dt_rank=ssm_dt_rank,
  1774. ssm_act_layer=ssm_act_layer,
  1775. ssm_conv=ssm_conv,
  1776. ssm_conv_bias=ssm_conv_bias,
  1777. ssm_drop_rate=ssm_drop_rate,
  1778. ssm_init=ssm_init,
  1779. forward_type=forward_type,
  1780. mlp_ratio=mlp_ratio,
  1781. mlp_act_layer=mlp_act_layer,
  1782. mlp_drop_rate=mlp_drop_rate,
  1783. use_checkpoint=use_checkpoint,
  1784. ))
  1785. return nn.Sequential(OrderedDict(
  1786. blocks=nn.Sequential(*blocks,),
  1787. downsample=downsample,
  1788. ))
  1789. def forward(self, x: torch.Tensor):
  1790. x = self.patch_embed(x)
  1791. if self.pos_embed is not None:
  1792. pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed
  1793. x = x + pos_embed
  1794. for layer in self.layers:
  1795. x = layer(x)
  1796. x = self.classifier(x)
  1797. return x
  1798. def flops(self, shape=(3, 224, 224), verbose=True):
  1799. # shape = self.__input_shape__[1:]
  1800. supported_ops={
  1801. "aten::silu": None, # as relu is in _IGNORED_OPS
  1802. "aten::neg": None, # as relu is in _IGNORED_OPS
  1803. "aten::exp": None, # as relu is in _IGNORED_OPS
  1804. "aten::flip": None, # as permute is in _IGNORED_OPS
  1805. # "prim::PythonOp.CrossScan": None,
  1806. # "prim::PythonOp.CrossMerge": None,
  1807. "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=verbose),
  1808. }
  1809. model = copy.deepcopy(self)
  1810. model.cuda().eval()
  1811. input = torch.randn((1, *shape), device=next(model.parameters()).device)
  1812. params = parameter_count(model)[""]
  1813. Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
  1814. del model, input
  1815. return sum(Gflops.values()) * 1e9
  1816. return f"params {params} GFLOPs {sum(Gflops.values())}"
  1817. # used to load ckpt from previous training code
  1818. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
  1819. def check_name(src, state_dict: dict = state_dict, strict=False):
  1820. if strict:
  1821. if prefix + src in list(state_dict.keys()):
  1822. return True
  1823. else:
  1824. key = prefix + src
  1825. for k in list(state_dict.keys()):
  1826. if k.startswith(key):
  1827. return True
  1828. return False
  1829. def change_name(src, dst, state_dict: dict = state_dict, strict=False):
  1830. if strict:
  1831. if prefix + src in list(state_dict.keys()):
  1832. state_dict[prefix + dst] = state_dict[prefix + src]
  1833. state_dict.pop(prefix + src)
  1834. else:
  1835. key = prefix + src
  1836. for k in list(state_dict.keys()):
  1837. if k.startswith(key):
  1838. new_k = prefix + dst + k[len(key):]
  1839. state_dict[new_k] = state_dict[k]
  1840. state_dict.pop(k)
  1841. if check_name("pos_embed", strict=True):
  1842. srcEmb: torch.Tensor = state_dict[prefix + "pos_embed"]
  1843. state_dict[prefix + "pos_embed"] = F.interpolate(srcEmb.float(), size=self.pos_embed.shape[2:4], align_corners=False, mode="bicubic").to(srcEmb.device)
  1844. change_name("patch_embed.proj", "patch_embed.0")
  1845. change_name("patch_embed.norm", "patch_embed.2")
  1846. for i in range(100):
  1847. for j in range(100):
  1848. change_name(f"layers.{i}.blocks.{j}.ln_1", f"layers.{i}.blocks.{j}.norm")
  1849. change_name(f"layers.{i}.blocks.{j}.self_attention", f"layers.{i}.blocks.{j}.op")
  1850. change_name(f"layers.{i}.downsample.norm", f"layers.{i}.downsample.{1}")
  1851. change_name(f"layers.{i}.downsample.reduction", f"layers.{i}.downsample.{2}")
  1852. change_name("norm", "classifier.norm")
  1853. change_name("head", "classifier.head")
  1854. return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  1855. # compatible with openmmlab
  1856. class Backbone_VSSM(VSSM):
  1857. def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer="ln", **kwargs):
  1858. kwargs.update(norm_layer=norm_layer)
  1859. super().__init__(**kwargs)
  1860. self.channel_first = (norm_layer.lower() in ["ln2d"])
  1861. self.out_indices = out_indices
  1862. for i in out_indices:
  1863. layer = LayerNorm(self.dims[i], channel_first=self.channel_first)
  1864. layer_name = f'outnorm{i}'
  1865. self.add_module(layer_name, layer)
  1866. del self.classifier
  1867. self.load_pretrained(pretrained)
  1868. def load_pretrained(self, ckpt=None, key="model"):
  1869. if ckpt is None:
  1870. return
  1871. try:
  1872. _ckpt = torch.load(open(ckpt, "rb"), map_location=torch.device("cpu"))
  1873. print(f"Successfully load ckpt {ckpt}")
  1874. incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
  1875. print(incompatibleKeys)
  1876. except Exception as e:
  1877. print(f"Failed loading checkpoint form {ckpt}: {e}")
  1878. def forward(self, x):
  1879. def layer_forward(l, x):
  1880. x = l.blocks(x)
  1881. y = l.downsample(x)
  1882. return x, y
  1883. x = self.patch_embed(x)
  1884. outs = []
  1885. for i, layer in enumerate(self.layers):
  1886. o, x = layer_forward(layer, x) # (B, H, W, C)
  1887. if i in self.out_indices:
  1888. norm_layer = getattr(self, f'outnorm{i}')
  1889. out = norm_layer(o)
  1890. if not self.channel_first:
  1891. out = out.permute(0, 3, 1, 2)
  1892. outs.append(out.contiguous())
  1893. if len(self.out_indices) == 0:
  1894. return x
  1895. return outs
  1896. ##########################################################
  1897. # main.py
  1898. ##########################################################
  1899. from timm.models import register_model
  1900. def load_checkpoint(path="", key="model"):
  1901. if path.startswith('https'):
  1902. checkpoint = torch.hub.load_state_dict_from_url(
  1903. path, map_location='cpu', check_hash=True)
  1904. else:
  1905. checkpoint = torch.load(path, map_location='cpu')
  1906. return checkpoint[key]
  1907. @register_model
  1908. def vmamba(**kwargs):
  1909. return VSSM(**kwargs)
  1910. @register_model
  1911. def vanilla_vmamba_tiny(pretrained=False, **kwargs):
  1912. model = VSSM(
  1913. depths=[2, 2, 9, 2], dims=96, drop_path_rate=0.2,
  1914. patch_size=4, in_chans=3, num_classes=1000,
  1915. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1916. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1917. ssm_init="v0", forward_type="v0",
  1918. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1919. patch_norm=True, norm_layer="ln",
  1920. downsample_version="v1", patchembed_version="v1",
  1921. use_checkpoint=False, posembed=False, imgsize=224,
  1922. )
  1923. if pretrained:
  1924. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v0cls/vssmtiny_dp01_ckpt_epoch_292.pth"))
  1925. return model
  1926. @register_model
  1927. def vanilla_vmamba_small(pretrained=False, **kwargs):
  1928. model = VSSM(
  1929. depths=[2, 2, 27, 2], dims=96, drop_path_rate=0.3,
  1930. patch_size=4, in_chans=3, num_classes=1000,
  1931. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1932. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1933. ssm_init="v0", forward_type="v0",
  1934. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1935. patch_norm=True, norm_layer="ln",
  1936. downsample_version="v1", patchembed_version="v1",
  1937. use_checkpoint=False, posembed=False, imgsize=224,
  1938. )
  1939. if pretrained:
  1940. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v0cls/vssmsmall_dp03_ckpt_epoch_238.pth"))
  1941. return model
  1942. @register_model
  1943. def vanilla_vmamba_base(pretrained=False, **kwargs):
  1944. model = VSSM(
  1945. depths=[2, 2, 27, 2], dims=128, drop_path_rate=0.6,
  1946. patch_size=4, in_chans=3, num_classes=1000,
  1947. ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1948. ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0,
  1949. ssm_init="v0", forward_type="v0",
  1950. mlp_ratio=0.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1951. patch_norm=True, norm_layer="ln",
  1952. downsample_version="v1", patchembed_version="v1",
  1953. use_checkpoint=False, posembed=False, imgsize=224,
  1954. )
  1955. if pretrained:
  1956. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v0cls/vssmbase_dp06_ckpt_epoch_241.pth"))
  1957. return model
  1958. @register_model
  1959. def vmamba_tiny_s2l5(pretrained=False, channel_first=True, **kwargs):
  1960. model = VSSM(
  1961. depths=[2, 2, 5, 2], dims=96, drop_path_rate=0.2,
  1962. patch_size=4, in_chans=3, num_classes=1000,
  1963. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1964. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1965. ssm_init="v0", forward_type="v05_noz",
  1966. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1967. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1968. downsample_version="v3", patchembed_version="v2",
  1969. use_checkpoint=False, posembed=False, imgsize=224,
  1970. )
  1971. if pretrained:
  1972. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v2cls/vssm_tiny_0230_ckpt_epoch_262.pth"))
  1973. return model
  1974. @register_model
  1975. def vmamba_small_s2l15(pretrained=False, channel_first=True, **kwargs):
  1976. model = VSSM(
  1977. depths=[2, 2, 15, 2], dims=96, drop_path_rate=0.3,
  1978. patch_size=4, in_chans=3, num_classes=1000,
  1979. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1980. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1981. ssm_init="v0", forward_type="v05_noz",
  1982. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1983. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  1984. downsample_version="v3", patchembed_version="v2",
  1985. use_checkpoint=False, posembed=False, imgsize=224,
  1986. )
  1987. if pretrained:
  1988. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v2cls/vssm_small_0229_ckpt_epoch_222.pth"))
  1989. return model
  1990. @register_model
  1991. def vmamba_base_s2l15(pretrained=False, channel_first=True, **kwargs):
  1992. model = VSSM(
  1993. depths=[2, 2, 15, 2], dims=128, drop_path_rate=0.6,
  1994. patch_size=4, in_chans=3, num_classes=1000,
  1995. ssm_d_state=1, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  1996. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  1997. ssm_init="v0", forward_type="v05_noz",
  1998. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  1999. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  2000. downsample_version="v3", patchembed_version="v2",
  2001. use_checkpoint=False, posembed=False, imgsize=224,
  2002. )
  2003. if pretrained:
  2004. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v2cls/vssm_base_0229_ckpt_epoch_237.pth"))
  2005. return model
  2006. @register_model
  2007. def vmamba_tiny_s1l8(pretrained=False, channel_first=True, **kwargs):
  2008. model = VSSM(
  2009. depths=[2, 2, 8, 2], dims=96, drop_path_rate=0.2,
  2010. patch_size=4, in_chans=3, num_classes=1000,
  2011. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  2012. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  2013. ssm_init="v0", forward_type="v05_noz",
  2014. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  2015. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  2016. downsample_version="v3", patchembed_version="v2",
  2017. use_checkpoint=False, posembed=False, imgsize=224,
  2018. )
  2019. if pretrained:
  2020. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v2cls/vssm1_tiny_0230s_ckpt_epoch_264.pth"))
  2021. return model
  2022. @register_model
  2023. def vmamba_small_s1l20(pretrained=False, channel_first=True, **kwargs):
  2024. model = VSSM(
  2025. depths=[2, 2, 20, 2], dims=96, drop_path_rate=0.3,
  2026. patch_size=4, in_chans=3, num_classes=1000,
  2027. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  2028. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  2029. ssm_init="v0", forward_type="v05_noz",
  2030. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  2031. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  2032. downsample_version="v3", patchembed_version="v2",
  2033. use_checkpoint=False, posembed=False, imgsize=224,
  2034. )
  2035. if pretrained:
  2036. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v2cls/vssm1_small_0229s_ckpt_epoch_240.pth"))
  2037. return model
  2038. @register_model
  2039. def vmamba_base_s1l20(pretrained=False, channel_first=True, **kwargs):
  2040. model = VSSM(
  2041. depths=[2, 2, 20, 2], dims=128, drop_path_rate=0.5,
  2042. patch_size=4, in_chans=3, num_classes=1000,
  2043. ssm_d_state=1, ssm_ratio=1.0, ssm_dt_rank="auto", ssm_act_layer="silu",
  2044. ssm_conv=3, ssm_conv_bias=False, ssm_drop_rate=0.0,
  2045. ssm_init="v0", forward_type="v05_noz",
  2046. mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, gmlp=False,
  2047. patch_norm=True, norm_layer=("ln2d" if channel_first else "ln"),
  2048. downsample_version="v3", patchembed_version="v2",
  2049. use_checkpoint=False, posembed=False, imgsize=224,
  2050. )
  2051. if pretrained:
  2052. model.load_state_dict(load_checkpoint("https://github.com/MzeroMiko/VMamba/releases/download/%23v2cls/vssm1_base_0229s_ckpt_epoch_225.pth"))
  2053. return model
  2054. def get_val_loader(batch_size=64, root="./val", img_size=224, sequential=True, num_workers=0):
  2055. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  2056. from torchvision import transforms, datasets
  2057. size = int((256 / 224) * img_size)
  2058. transform = transforms.Compose([
  2059. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  2060. transforms.CenterCrop((img_size, img_size)),
  2061. transforms.ToTensor(),
  2062. transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
  2063. ])
  2064. dataset = datasets.ImageFolder(root, transform=transform)
  2065. if sequential:
  2066. sampler = torch.utils.data.SequentialSampler(dataset)
  2067. else:
  2068. sampler = torch.utils.data.DistributedSampler(dataset)
  2069. data_loader = torch.utils.data.DataLoader(
  2070. dataset, sampler=sampler,
  2071. batch_size=batch_size,
  2072. shuffle=False,
  2073. num_workers=num_workers,
  2074. pin_memory=True,
  2075. drop_last=False
  2076. )
  2077. return data_loader
  2078. @torch.no_grad()
  2079. def validate(data_loader, model, amp_enable=True, print_freq=100000):
  2080. from timm.utils import accuracy, AverageMeter
  2081. criterion = nn.CrossEntropyLoss()
  2082. model.cuda()
  2083. model.eval()
  2084. batch_time = AverageMeter()
  2085. loss_meter = AverageMeter()
  2086. acc1_meter = AverageMeter()
  2087. acc5_meter = AverageMeter()
  2088. end = time.time()
  2089. for idx, (images, target) in enumerate(data_loader):
  2090. images = images.cuda(non_blocking=True)
  2091. target = target.cuda(non_blocking=True)
  2092. # compute output
  2093. with torch.cuda.amp.autocast(enabled=amp_enable):
  2094. output = model(images)
  2095. # measure accuracy and record loss
  2096. loss = criterion(output, target)
  2097. acc1, acc5 = accuracy(output, target, topk=(1, 5))
  2098. # acc1 = reduce_tensor(acc1)
  2099. # acc5 = reduce_tensor(acc5)
  2100. # loss = reduce_tensor(loss)
  2101. loss_meter.update(loss.item(), target.size(0))
  2102. acc1_meter.update(acc1.item(), target.size(0))
  2103. acc5_meter.update(acc5.item(), target.size(0))
  2104. # measure elapsed time
  2105. batch_time.update(time.time() - end)
  2106. end = time.time()
  2107. if (idx + 1) % print_freq == 0:
  2108. memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
  2109. print(
  2110. f'Test: [{idx}/{len(data_loader)}]\t'
  2111. f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
  2112. f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
  2113. f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
  2114. f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
  2115. f'Mem {memory_used:.0f}MB')
  2116. # print(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
  2117. return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
  2118. @torch.no_grad()
  2119. def throughput(data_loader, model):
  2120. model.cuda()
  2121. model.eval()
  2122. for idx, (images, _) in enumerate(data_loader):
  2123. images = images.cuda(non_blocking=True)
  2124. batch_size = images.shape[0]
  2125. for i in range(50):
  2126. model(images)
  2127. torch.cuda.synchronize()
  2128. print(f"throughput averaged with 30 times")
  2129. tic1 = time.time()
  2130. for i in range(30):
  2131. model(images)
  2132. torch.cuda.synchronize()
  2133. tic2 = time.time()
  2134. print(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  2135. return
  2136. def do_validate(name="vmamba_tiny_s1l8", data="/media/memfs/ImageNet_ILSVRC2012/val"):
  2137. from timm import create_model
  2138. if True:
  2139. torch.backends.cudnn.enabled = True
  2140. torch.backends.cudnn.benchmark = True
  2141. torch.backends.cudnn.deterministic = True
  2142. data_loader_val = get_val_loader(batch_size=64, root=data, num_workers=4)
  2143. model = create_model(name, pretrained=True)
  2144. acc1_ema, acc5_ema, loss_ema = validate(data_loader_val, model)
  2145. print(acc1_ema, acc5_ema, loss_ema)
  2146. def do_throughput(name="vmamba_tiny_s1l8", data="/media/memfs/ImageNet_ILSVRC2012/val"):
  2147. from timm import create_model
  2148. if True:
  2149. torch.backends.cudnn.enabled = True
  2150. torch.backends.cudnn.benchmark = True
  2151. torch.backends.cudnn.deterministic = True
  2152. data_loader_val = get_val_loader(batch_size=128, root=data, num_workers=4)
  2153. model = create_model(name, pretrained=True)
  2154. throughput(data_loader_val, model)
  2155. if __name__ == "__main__":
  2156. # do_validate("vanilla_vmamba_tiny") # 82.17106973558698 96.03223806724185 0.7879069638634182
  2157. # do_validate("vanilla_vmamba_small") # 83.4609923402307 96.47021178881855 0.7160880894021359
  2158. # do_validate("vanilla_vmamba_base") # 83.72897626157689 96.62420254754197 0.6968230148378597
  2159. # do_validate("vmamba_tiny_s2l5") # 82.48905065741832 95.99624022634936 0.7805328359985901
  2160. # do_validate("vmamba_small_s2l15") # 83.64898106090746 96.59420434667109 0.7185911423439594
  2161. # do_validate("vmamba_base_s2l15") # 83.87896726211686 96.71219726709586 0.7198247987933224
  2162. # do_validate("vmamba_tiny_s1l8") # 83.87896726211686 96.71219726709586 0.7198247987933224
  2163. # do_validate("vmamba_small_s1l20") # 83.33899965941008 96.42621442606632 nan
  2164. # do_validate("vmamba_base_s1l20") # 83.79097254317328 96.61420314781112 0.7243299191111033
  2165. # do_throughput("vanilla_vmamba_tiny")
  2166. # do_throughput("vanilla_vmamba_small")
  2167. # do_throughput("vanilla_vmamba_base")
  2168. # do_throughput("vmamba_tiny_s2l5")
  2169. # do_throughput("vmamba_small_s2l15")
  2170. # do_throughput("vmamba_base_s2l15")
  2171. do_throughput("vmamba_tiny_s1l8")
  2172. # do_throughput("vmamba_small_s1l20")
  2173. # do_throughput("vmamba_base_s1l20")