convnext_timm.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the MIT license
  4. """ConvNext TIMM version with S4ND integration.
  5. Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
  6. Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
  7. Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
  8. """
  9. from collections import OrderedDict
  10. from functools import partial
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. import numpy as np
  15. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  16. from timm.models.fx_features import register_notrace_module
  17. # from timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq
  18. from timm.models.helpers import named_apply, build_model_with_cfg
  19. from timm.models.layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp
  20. from timm.models.registry import register_model
  21. import copy
  22. from einops import rearrange, repeat
  23. from einops.layers.torch import Rearrange
  24. from omegaconf import OmegaConf
  25. # S4 imports
  26. import src.utils as utils
  27. import src.utils.registry as registry
  28. from src.models.nn import TransposedLinear
  # Names exported by a star-import of this module.
  __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
  def _cfg(url='', **kwargs):
      """Build a pretrained-weights config dict.

      Starts from the defaults listed below (1000 classes, 3x224x224 input,
      7x7 pool, bicubic interpolation, ImageNet mean/std) and lets any keyword
      argument override or extend them.

      Args:
          url (str): location of the checkpoint; empty string when there is none.
          **kwargs: entries that replace / add to the defaults.

      Returns:
          dict: the merged configuration.
      """
      cfg = {
          'url': url,
          'num_classes': 1000,
          'input_size': (3, 224, 224),
          'pool_size': (7, 7),
          'crop_pct': 0.875,
          'interpolation': 'bicubic',
          'mean': IMAGENET_DEFAULT_MEAN,
          'std': IMAGENET_DEFAULT_STD,
          'first_conv': 'stem.0',
          'classifier': 'head.fc',
      }
      # Caller-supplied entries win over the defaults.
      cfg.update(kwargs)
      return cfg
  # Pretrained-weight configs, keyed by model variant name.
  default_cfgs = {
      # ImageNet-1k, 224px
      'convnext_tiny': _cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
      'convnext_small': _cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
      'convnext_base': _cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
      'convnext_large': _cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),

      # "hnf" variants
      'convnext_nano_hnf': _cfg(url=''),
      'convnext_tiny_hnf': _cfg(
          url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
          crop_pct=0.95,
      ),

      # in22ft1k checkpoints, 224px
      'convnext_tiny_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
      'convnext_small_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
      'convnext_base_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
      'convnext_large_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
      'convnext_xlarge_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),

      # in22ft1k checkpoints, 384px
      'convnext_tiny_384_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
          input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
      'convnext_small_384_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
          input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
      'convnext_base_384_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
          input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
      'convnext_large_384_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
          input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
      'convnext_xlarge_384_in22ft1k': _cfg(
          url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
          input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),

      # in22k checkpoints (21841 classes)
      'convnext_tiny_in22k': _cfg(
          url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
      'convnext_small_in22k': _cfg(
          url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
      'convnext_base_in22k': _cfg(
          url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
      'convnext_large_in22k': _cfg(
          url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
      'convnext_xlarge_in22k': _cfg(
          url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
  }
  84. def _is_contiguous(tensor: torch.Tensor) -> bool:
  85. # jit is oh so lovely :/
  86. # if torch.jit.is_tracing():
  87. # return True
  88. if torch.jit.is_scripting():
  89. return tensor.is_contiguous()
  90. else:
  91. return tensor.is_contiguous(memory_format=torch.contiguous_format)
  def get_num_layer_for_convnext(var_name, variant='tiny'):
      """
      Map a parameter name to its layer group id.

      Divide [3, 3, 27, 3] layers into 12 groups; each group is three
      consecutive blocks, including possible neighboring downsample layers;
      adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py

      Args:
          var_name (str): dotted parameter name. For stage parameters the stage
              index is read from the 3rd dot-separated field and the block index
              from the 5th (e.g. ``model.stages.<stage>.blocks.<block>...``).
          variant (str): with ``'tiny'`` every block of stage 2 gets its own
              group (``3 + block``); any other value puts three consecutive
              blocks of stage 2 in one group (``3 + block // 3``).

      Returns:
          int: 0 for stem parameters, 1..12 for stage parameters, and 13 for
          anything that is neither stem nor stage.

      Raises:
          ValueError: if the stage index is not one of 0, 1, 2, 3 (previously
              this surfaced as an ``UnboundLocalError``), or if an index field
              is not an integer.
          IndexError: if the name has too few dot-separated fields.
      """
      num_max_layer = 12

      if "stem" in var_name:
          return 0

      # note: moved norm_layer outside of downsample module
      is_downsample = "downsample" in var_name or "norm_layer" in var_name
      if not is_downsample and "stages" not in var_name:
          return num_max_layer + 1

      fields = var_name.split('.')
      stage_id = int(fields[2])
      # Block parameters also carry a block index; parse it up front so a
      # malformed name fails the same way it always did.
      block_id = None if is_downsample else int(fields[4])

      if stage_id not in (0, 1, 2, 3):
          raise ValueError(
              "unexpected stage index {} in parameter name {!r}".format(stage_id, var_name))

      if is_downsample:
          # The downsample / pre-norm of a stage is grouped with what precedes it.
          if stage_id == 0:
              return 0
          if stage_id == 3:
              return 12
          return stage_id + 1

      if stage_id in (0, 1):
          return stage_id + 1
      if stage_id == 3:
          return 12
      # stage 2 is the deep stage and is spread over several groups
      if variant == 'tiny':
          return 3 + block_id
      return 3 + block_id // 3
  def get_num_layer_for_convnext_tiny(var_name):
      """Return the layer group id of ``var_name`` using the 'tiny' grouping."""
      return get_num_layer_for_convnext(var_name, variant='tiny')
  @register_notrace_module
  class DropoutNd(nn.Module):
      """Dropout whose mask can be shared ("tied") across all trailing dimensions."""

      def __init__(self, p: float = 0.5, tie=True):
          """ tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
          For some reason tie=False is dog slow, prob something wrong with torch.distribution

          Raises:
              ValueError: if ``p`` is negative or >= 1.
          """
          super().__init__()
          if p < 0 or p >= 1:
              raise ValueError("dropout probability has to be in [0, 1), but got {}".format(p))
          self.p = p
          self.tie = tie
          # Kept as an attribute; forward() draws its mask with torch.rand instead.
          self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

      def forward(self, X):
          """ X: (batch, dim, lengths...)

          In eval mode the input is returned as-is. In training mode entries
          are kept with probability ``1 - p`` and the result is rescaled by
          ``1 / (1 - p)``.
          """
          if not self.training:
              return X

          keep_prob = 1. - self.p
          if self.tie:
              # one mask value per (batch, dim), broadcast over the remaining dims
              mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
          else:
              mask_shape = X.shape
          mask = torch.rand(*mask_shape, device=X.device) < keep_prob
          return X * mask * (1.0 / keep_prob)
  @register_notrace_module
  class LayerNorm2d(nn.LayerNorm):
      r""" LayerNorm over the channel dim of channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
      """

      def __init__(self, normalized_shape, eps=1e-6):
          super().__init__(normalized_shape, eps=eps)

      def forward(self, x) -> torch.Tensor:
          """Normalize ``x`` over dim 1 and apply the affine weight / bias."""
          if not _is_contiguous(x):
              # Non-contiguous input: normalize over dim 1 by hand.
              var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
              x = (x - mean) * torch.rsqrt(var + self.eps)
              x = x * self.weight[:, None, None] + self.bias[:, None, None]
              return x

          # Contiguous input: move channels last, use the fused layer_norm, move back.
          channels_last = x.permute(0, 2, 3, 1)
          normed = F.layer_norm(
              channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
          return normed.permute(0, 3, 1, 2)
  @register_notrace_module
  class LayerNorm3d(nn.LayerNorm):
      r""" LayerNorm over the channel dim of channels_first tensors with 3 trailing dimensions (ie N, C, L, H, W).
      """

      def __init__(self, normalized_shape, eps=1e-6):
          super().__init__(normalized_shape, eps=eps)

      def forward(self, x) -> torch.Tensor:
          """Normalize ``x`` over dim 1 and apply the affine weight / bias."""
          if not _is_contiguous(x):
              # Non-contiguous input: normalize over dim 1 by hand.
              var, mean = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
              x = (x - mean) * torch.rsqrt(var + self.eps)
              x = x * self.weight[:, None, None, None] + self.bias[:, None, None, None]
              return x

          # Contiguous input: move channels last, use the fused layer_norm, move back.
          channels_last = x.permute(0, 2, 3, 4, 1)
          normed = F.layer_norm(
              channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
          return normed.permute(0, 4, 1, 2, 3)
  @register_notrace_module
  class TransposedLN(nn.Module):
      """Normalization over dim 1 with a single learnable scalar shift and scale.

      NOTE(review): the constructor arguments ``d`` and ``scalar`` are accepted
      but not used by this implementation.
      """

      def __init__(self, d, scalar=True):
          super().__init__()
          self.m = nn.Parameter(torch.zeros(1))
          self.s = nn.Parameter(torch.ones(1))
          # Tag both scalars with an "_optim" dict requesting zero weight decay.
          for param in (self.m, self.s):
              setattr(param, "_optim", {"weight_decay": 0.0})

      def forward(self, x):
          """Return ``(s / std) * (x - mean + m)`` with statistics taken over dim 1."""
          std, mean = torch.std_mean(x, dim=1, unbiased=False, keepdim=True)
          scale = self.s / std
          return scale * (x - mean + self.m)
  class Conv2dWrapper(nn.Module):
      """
      Thin nn.Conv2d wrapper whose forward() accepts (and ignores) a
      ``resolution`` argument, so it can be called the same way as the s4
      conv layer.
      """

      def __init__(self, dim_in, dim_out, **kwargs):
          """All extra keyword arguments are forwarded to ``nn.Conv2d``."""
          super().__init__()
          self.conv = nn.Conv2d(dim_in, dim_out, **kwargs)

      def forward(self, x, resolution=None):
          """Apply the convolution; ``resolution`` is unused."""
          out = self.conv(x)
          return out
  class S4DownSample(nn.Module):
      """ S4 conv block with downsampling using avg pool

      Pipeline: s4 layer -> optional GELU -> average pool -> linear projection
      over channels -> optional GLU.

      Args:
          downsample_layer (dict): config for creating s4 layer
          in_ch (int): num input channels
          out_ch (int): num output channels
          stride (int): downsample factor in avg pool
          activate (bool): apply GELU after the s4 layer
          glu (bool): project to ``2 * out_ch`` channels and gate with GLU over dim 1
          pool3d (bool): use AvgPool3d instead of AvgPool2d
      """

      def __init__(self, downsample_layer, in_ch, out_ch, stride=1, activate=False, glu=False, pool3d=False):
          super().__init__()
          # create s4
          self.s4conv = utils.instantiate(registry.layer, downsample_layer, in_ch)
          self.act = nn.GELU() if activate else nn.Identity()

          pool_cls = nn.AvgPool3d if pool3d else nn.AvgPool2d
          self.avgpool = pool_cls(kernel_size=stride, stride=stride)

          self.glu = glu
          # GLU halves the channel count, so project to twice the target width first.
          proj_ch = 2 * out_ch if self.glu else out_ch
          self.fc = TransposedLinear(in_ch, proj_ch)

      def forward(self, x, resolution=1):
          """Run the pipeline; ``resolution`` is forwarded to the s4 layer."""
          x = self.act(self.s4conv(x, resolution))
          x = self.fc(self.avgpool(x))
          if not self.glu:
              return x
          return F.glu(x, dim=1)
  class ConvNeXtBlock(nn.Module):
      """ ConvNeXt Block

      # previous convnext notes:
      There are two equivalent implementations:
        (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
        (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
      Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
      choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
      is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.

      # two options for convs are:
        - conv2d, depthwise (original)
        - s4nd, used if a layer config passed

      Args:
          dim (int): Number of input channels.
          drop_path (float): Stochastic depth rate. Default: 0.0
          ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
          conv_mlp (bool): use ConvMlp (implementation 1) instead of Mlp (implementation 2).
          mlp_ratio (int): hidden width of the MLP as a multiple of ``dim``.
          norm_layer (callable): norm constructor, called as ``norm_layer(dim)``; required.
          layer (config/dict): config for s4 layer
      """

      def __init__(self,
                   dim,
                   drop_path=0.,
                   ls_init_value=1e-6,
                   conv_mlp=False,
                   mlp_ratio=4,
                   norm_layer=None,
                   layer=None,
                   ):
          super().__init__()
          assert norm_layer is not None
          self.use_conv_mlp = conv_mlp
          mlp_cls = ConvMlp if conv_mlp else Mlp

          # Spatial mixing: an s4 layer when a config is given, else the original 7x7 depthwise conv.
          if layer is not None:
              self.conv_dw = utils.instantiate(registry.layer, layer, dim)
          else:
              self.conv_dw = Conv2dWrapper(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv

          self.norm = norm_layer(dim)
          hidden_dim = int(mlp_ratio * dim)
          self.mlp = mlp_cls(dim, hidden_dim, act_layer=nn.GELU)

          # Layer scale is disabled when the init value is not positive.
          if ls_init_value > 0:
              self.gamma = nn.Parameter(ls_init_value * torch.ones(dim))
          else:
              self.gamma = None

          if drop_path > 0.:
              self.drop_path = DropPath(drop_path)
          else:
              self.drop_path = nn.Identity()

      def forward(self, x, resolution=1):
          """Residual block: conv -> norm -> MLP -> layer scale -> drop path, plus the input."""
          residual = x
          x = self.conv_dw(x, resolution)

          if not self.use_conv_mlp:
              # channels-last norm / MLP: permute in, then back out
              x = x.permute(0, 2, 3, 1)
              x = self.mlp(self.norm(x))
              x = x.permute(0, 3, 1, 2)
          else:
              x = self.mlp(self.norm(x))

          if self.gamma is not None:
              x = x.mul(self.gamma.reshape(1, -1, 1, 1))
          return self.drop_path(x) + residual
  class Stem(nn.Module):
      """Input stem of the network, selectable through ``stem_type``.

      Supported types: 'patch' (regular convnext), 'depthwise_patch',
      'new_patch', 'new_s4nd_patch', 's4nd_patch', 's4nd' and 'default'.
      The three s4nd types are built as ``pre_stem -> stem -> post_stem`` and
      pass ``resolution`` to ``stem``; the others are a single ``stem`` module.

      NOTE(review): the ``stride`` argument is accepted but not used here, and
      the 's4nd' and 'default' stems produce a hard-coded 64 channels rather
      than ``out_ch``.

      Raises:
          NotImplementedError: for an unknown ``stem_type``.
      """

      def __init__(self,
                   stem_type='patch',  # regular convnext
                   in_ch=3,
                   out_ch=64,
                   img_size=None,
                   patch_size=4,
                   stride=4,
                   stem_channels=32,
                   stem_layer=None,
                   stem_l_max=None,
                   downsample_act=False,
                   downsample_glu=False,
                   norm_layer=None,
                   ):
          super().__init__()
          self.stem_type = stem_type
          # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
          self.pre_stem = None
          self.post_stem = None

          def pool_glu_norm_tail():
              # Shared tail: avg-pool by patch_size, project to 2*out_ch, gate with GLU, normalize.
              return [
                  nn.AvgPool2d(kernel_size=patch_size, stride=patch_size),
                  TransposedLinear(stem_channels, 2*out_ch),
                  nn.GLU(dim=1),
                  norm_layer(out_ch),
              ]

          if stem_type == 'patch':
              print("stem type: ", 'patch')
              patchify = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size, stride=patch_size)
              self.stem = nn.Sequential(patchify, norm_layer(out_ch))

          elif stem_type == 'depthwise_patch':
              print("stem type: ", 'depthwise_patch')
              pointwise = nn.Conv2d(in_ch, stem_channels, kernel_size=1, stride=1, padding=0)
              depthwise = nn.Conv2d(
                  stem_channels, stem_channels, kernel_size=patch_size, stride=1,
                  padding='same', groups=stem_channels)
              self.stem = nn.Sequential(pointwise, depthwise, *pool_glu_norm_tail())

          elif stem_type == 'new_patch':
              print("stem type: ", 'new_patch')
              dense = nn.Conv2d(in_ch, stem_channels, kernel_size=patch_size, stride=1, padding='same')
              self.stem = nn.Sequential(dense, *pool_glu_norm_tail())

          elif stem_type == 'new_s4nd_patch':
              print("stem type: ", 'new_s4nd_patch')
              stem_layer_copy = copy.deepcopy(stem_layer)
              assert stem_l_max is not None, "need to provide a stem_l_max to use stem=new_s4nd_patch"
              stem_layer_copy["l_max"] = stem_l_max
              self.pre_stem = nn.Identity()
              self.stem = utils.instantiate(registry.layer, stem_layer_copy, in_ch, out_channels=stem_channels)
              self.post_stem = nn.Sequential(*pool_glu_norm_tail())

          elif stem_type == 's4nd_patch':
              print("stem type: ", "s4nd_patch")
              stem_layer_copy = copy.deepcopy(stem_layer)
              stem_layer_copy["l_max"] = img_size
              self.pre_stem = nn.Conv2d(in_ch, stem_channels, kernel_size=1, stride=1, padding=0)
              # s4 + norm + avg pool + linear
              self.stem = S4DownSample(
                  stem_layer_copy, stem_channels, out_ch,
                  stride=patch_size, activate=downsample_act, glu=downsample_glu)
              self.post_stem = norm_layer(out_ch)

          elif stem_type == 's4nd':
              # mix of conv2d + s4
              print("stem type: ", 's4nd')
              stem_layer_copy = copy.deepcopy(stem_layer)
              stem_layer_copy["l_max"] = img_size
              # Built before pre_stem so module construction order is unchanged.
              s4_downsample = S4DownSample(
                  stem_layer_copy, stem_channels, 64,
                  stride=2, activate=downsample_act, glu=downsample_glu)
              self.pre_stem = nn.Sequential(
                  nn.Conv2d(in_ch, stem_channels, kernel_size=1, stride=1, padding=0),
                  norm_layer(stem_channels),
                  nn.GELU(),
              )
              self.stem = s4_downsample
              self.post_stem = nn.Identity()

          # regular strided conv downsample
          elif stem_type == 'default':
              print("stem type: DEFAULT. Make sure this is what you want.")
              self.stem = nn.Sequential(
                  nn.Conv2d(in_ch, 32, kernel_size=3, stride=2, padding=1),
                  norm_layer(32),
                  nn.GELU(),
                  nn.Conv2d(32, 64, kernel_size=3, padding=1),
              )

          else:
              raise NotImplementedError("provide a valid stem type!")

      def forward(self, x, resolution):
          """Apply the stem; only the s4nd variants receive ``resolution``."""
          if self.stem_type not in ('s4nd', 's4nd_patch', 'new_s4nd_patch'):
              return self.stem(x)

          # if using s4nd layer, need to pass resolution
          x = self.pre_stem(x)
          x = self.stem(x, resolution)
          return self.post_stem(x)
  class ConvNeXtStage(nn.Module):
      """
      Will create a stage, made up of downsampling and conv blocks.
      There are 2 choices for each of these:
          downsampling: s4 or strided conv (original)
          conv stage: s4 or conv2d (original)

      Args:
          in_chs (int): channels entering the stage.
          out_chs (int): channels produced by the stage.
          img_size (list[int]): spatial size of the stage input; used to set
              ``l_max`` of the s4 configs. Required whenever ``stage_layer`` is given.
          stride (int): downsample factor at the start of the stage.
          depth (int): number of ConvNeXtBlocks.
          dp_rates (list[float] | None): per-block drop-path rates; zeros when omitted.
          ls_init_value (float): layer-scale init value passed to the blocks.
          conv_mlp (bool): passed to the blocks; also selects which norm they get
              (``norm_layer`` if True, else ``cl_norm_layer``).
          norm_layer (callable): norm constructor used before the downsample layer.
          cl_norm_layer (callable): norm constructor for blocks when ``conv_mlp`` is False.
          stage_layer (config/dict | None): s4 layer config; ``None`` selects the conv2d block.
          downsample_type (str | None): 's4nd' for the s4 downsample, anything else
              for a strided conv.
          downsample_act (bool): forwarded to S4DownSample as ``activate``.
          downsample_glu (bool): forwarded to S4DownSample as ``glu``.

      The caller's ``stage_layer`` config is never modified: both the downsample
      layer and the blocks work on their own deep copies.
      """

      def __init__(
              self,
              in_chs,
              out_chs,
              img_size=None,
              stride=2,
              depth=2,
              dp_rates=None,
              ls_init_value=1.0,
              conv_mlp=False,
              norm_layer=None,
              cl_norm_layer=None,
              # cross_stage=False,
              stage_layer=None,  # config
              # downsample_layer=None,
              downsample_type=None,
              downsample_act=False,
              downsample_glu=False,
      ):
          super().__init__()
          self.grad_checkpointing = False
          self.downsampling = False

          # A downsample layer is only created when the stage changes width or resolution.
          if in_chs != out_chs or stride > 1:
              self.downsampling = True
              # s4 type copies config from corresponding stage layer
              if downsample_type == 's4nd':
                  print("s4nd downsample")
                  downsample_layer_copy = copy.deepcopy(stage_layer)
                  downsample_layer_copy["l_max"] = img_size  # always need to update curr l_max
                  self.norm_layer = norm_layer(in_chs)
                  # mimics strided conv but w/s4
                  self.downsample = S4DownSample(
                      downsample_layer_copy, in_chs, out_chs,
                      stride=stride, activate=downsample_act, glu=downsample_glu)
              # strided conv
              else:
                  print("strided conv downsample")
                  self.norm_layer = norm_layer(in_chs)
                  self.downsample = Conv2dWrapper(in_chs, out_chs, kernel_size=stride, stride=stride)

          if stage_layer is not None:
              # Work on a private copy so the caller's config is left untouched
              # (the downsample path above already does the same).
              stage_layer = copy.deepcopy(stage_layer)
              # blocks run after downsampling, so their l_max is the reduced size
              stage_layer["l_max"] = [x // stride for x in img_size]

          dp_rates = dp_rates or [0.] * depth
          block_norm = norm_layer if conv_mlp else cl_norm_layer
          self.blocks = nn.ModuleList()
          for j in range(depth):
              self.blocks.append(
                  ConvNeXtBlock(
                      dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
                      norm_layer=block_norm, layer=stage_layer)
              )

      def forward(self, x, resolution=1):
          """Optionally norm + downsample, then run every block; ``resolution`` is passed through."""
          # When not downsampling no downsample layer exists at all, since an
          # Identity could not accept the pass-through ``resolution`` argument.
          if self.downsampling:
              x = self.norm_layer(x)
              x = self.downsample(x, resolution)
          for block in self.blocks:
              x = block(x, resolution)
          return x
  class ConvNeXt(nn.Module):
      r""" ConvNeXt
          A PyTorch impl of : `A ConvNet for the 2020s`  - https://arxiv.org/pdf/2201.03545.pdf

      Args:
          in_chans (int): Number of input image channels. Default: 3
          num_classes (int): Number of classes for classification head. Default: 1000
          global_pool (str): pooling type given to SelectAdaptivePool2d in the head.
          output_stride (int): must be 32.
          patch_size (int): patch size / stride of the stem.
          stem_channels (int): intermediate channel count used by some stems.
          depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
          dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
          ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
          conv_mlp (bool): use the conv MLP in blocks.
          stem_type (str): which Stem variant to build.
          stem_l_max: l_max for the 'new_s4nd_patch' stem.
          downsample_type (str): 's4nd' for s4 downsampling, otherwise strided conv.
          downsample_act (bool), downsample_glu (bool): forwarded to the s4 downsample layers.
          head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
          head_norm_first (bool): norm -> pool -> fc ordering instead of pool -> norm -> fc.
          norm_layer (callable | None): custom norm constructor; requires ``conv_mlp``.
          custom_ln (bool): use TransposedLN as the default norm instead of LayerNorm2d.
          drop_head (float): Head dropout rate
          drop_path_rate (float): Stochastic depth rate. Default: 0.
          layer (config | None): shared s4 layer config merged into every stage / stem config.
          stem_layer (config | None): s4 config for the stem.
          stage_layers (list | None): one s4 config per stage. When both ``layer`` and
              ``stage_layers`` are omitted every stage uses the plain conv2d block.
          img_size (list[int]): spatial size of the input; required.

      Raises:
          ValueError: if ``img_size`` is not provided.
      """

      def __init__(
              self,
              in_chans=3,
              num_classes=1000,
              global_pool='avg',
              output_stride=32,
              patch_size=4,
              stem_channels=8,
              depths=(3, 3, 9, 3),
              dims=(96, 192, 384, 768),
              ls_init_value=1e-6,
              conv_mlp=False,  # whether to transpose channels to last dim inside MLP
              stem_type='patch',  # supports `s4nd` + avg pool
              stem_l_max=None,  # len of l_max in stem (if using s4)
              downsample_type='patch',  # supports `s4nd` + avg pool
              downsample_act=False,
              downsample_glu=False,
              head_init_scale=1.,
              head_norm_first=False,
              norm_layer=None,
              custom_ln=False,
              drop_head=0.,
              drop_path_rate=0.,
              layer=None,  # Shared config dictionary for the core layer
              stem_layer=None,
              stage_layers=None,
              img_size=None,
              # **kwargs,  # catch all
      ):
          super().__init__()
          assert output_stride == 32
          if img_size is None:
              # The per-stage sizes below are derived from img_size; fail early with a
              # clear message instead of a late "NoneType is not iterable".
              raise ValueError("img_size must be provided as a sequence of spatial sizes, e.g. [224, 224]")

          if norm_layer is None:
              if custom_ln:
                  norm_layer = TransposedLN
              else:
                  norm_layer = partial(LayerNorm2d, eps=1e-6)
              cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
          else:
              assert conv_mlp,\
                  'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
              cl_norm_layer = norm_layer

          self.num_classes = num_classes
          self.drop_head = drop_head
          self.feature_info = []  # left empty; nothing in this class fills it

          # Spatial size at the input, after the stem, and after every stage.
          self._img_sizes = [img_size]

          # Broadcast dictionaries
          if layer is not None:
              stage_layers = [OmegaConf.merge(layer, s) for s in stage_layers]
              stem_layer = OmegaConf.merge(layer, stem_layer)
          elif stage_layers is None:
              # No s4 config at all: every stage falls back to the conv2d block.
              stage_layers = [None] * 4

          # instantiate stem
          self.stem = Stem(
              stem_type=stem_type,
              in_ch=in_chans,
              out_ch=dims[0],
              img_size=img_size,
              patch_size=patch_size,
              stride=patch_size,
              stem_channels=stem_channels,
              stem_layer=stem_layer,
              stem_l_max=stem_l_max,
              downsample_act=downsample_act,
              downsample_glu=downsample_glu,
              norm_layer=norm_layer,
          )

          # The 's4nd' and 'default' stems downsample by 2 and emit 64 channels.
          if stem_type == 's4nd' or stem_type == 'default':
              stem_stride = 2
              prev_chs = 64
          else:
              stem_stride = patch_size
              prev_chs = dims[0]

          curr_img_size = [x // stem_stride for x in img_size]
          self._img_sizes.append(curr_img_size)

          self.stages = nn.ModuleList()
          dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
          # 4 feature resolution stages, each consisting of multiple residual blocks
          for i in range(4):
              # if stem downsampled by 4, then in stage 0, we don't downsample
              # if stem downsampled by 2, then in stage 0, we downsample by 2
              # all other stages we downsample by 2 no matter what
              stride = 1 if i == 0 and stem_stride == 4 else 2  # stride 1 is no downsample (because already ds in stem)
              out_chs = dims[i]
              self.stages.append(ConvNeXtStage(
                  prev_chs,
                  out_chs,
                  img_size=curr_img_size,
                  stride=stride,
                  depth=depths[i],
                  dp_rates=dp_rates[i],
                  ls_init_value=ls_init_value,
                  conv_mlp=conv_mlp,
                  norm_layer=norm_layer,
                  cl_norm_layer=cl_norm_layer,
                  stage_layer=stage_layers[i],
                  downsample_type=downsample_type,
                  downsample_act=downsample_act,
                  downsample_glu=downsample_glu,
              ))
              prev_chs = out_chs
              curr_img_size = [x // stride for x in curr_img_size]  # update image size for next stage
              self._img_sizes.append(curr_img_size)

          self.num_features = prev_chs

          # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
          # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
          self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
          self.head = nn.Sequential(OrderedDict([
              ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
              ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
              ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
              ('drop', nn.Dropout(self.drop_head)),
              ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))

          named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

      @torch.jit.ignore
      def group_matcher(self, coarse=False):
          """Return regex patterns that group parameters into stem / blocks."""
          return dict(
              stem=r'^stem',
              blocks=r'^stages\.(\d+)' if coarse else [
                  (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
                  (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                  (r'^norm_pre', (99999,))
              ]
          )

      @torch.jit.ignore
      def set_grad_checkpointing(self, enable=True):
          """Set the ``grad_checkpointing`` flag on every stage."""
          for s in self.stages:
              s.grad_checkpointing = enable

      @torch.jit.ignore
      def get_classifier(self):
          """Return the final classification layer (``head.fc``)."""
          return self.head.fc

      def reset_classifier(self, num_classes=0, global_pool=None):
          """Replace the classifier, and the pooling / flatten layers when ``global_pool`` is given."""
          if global_pool is not None:
              self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
              self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
          self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

      def forward_features(self, x, resolution=1):
          """Run stem, all stages and ``norm_pre``; ``resolution`` is passed to stem and stages."""
          x = self.stem(x, resolution)
          for stage in self.stages:
              x = stage(x, resolution)
          x = self.norm_pre(x)
          return x

      def forward_head(self, x, pre_logits: bool = False):
          """Pool, norm, flatten and drop; apply ``fc`` unless ``pre_logits`` is set."""
          # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
          x = self.head.global_pool(x)
          x = self.head.norm(x)
          x = self.head.flatten(x)
          x = self.head.drop(x)
          return x if pre_logits else self.head.fc(x)

      def forward(self, x, resolution=1, state=None):
          """Return ``(head output, None)``; ``state`` is accepted but unused."""
          x = self.forward_features(x, resolution)
          x = self.forward_head(x)
          return x, None
  624. def _init_weights(module, name=None, head_init_scale=1.0):
  625. if isinstance(module, nn.Conv2d):
  626. trunc_normal_(module.weight, std=.02)
  627. nn.init.constant_(module.bias, 0)
  628. elif isinstance(module, nn.Linear):
  629. trunc_normal_(module.weight, std=.02)
  630. # check if has bias first
  631. if module.bias is not None:
  632. nn.init.constant_(module.bias, 0)
  633. if name and 'head.' in name:
  634. module.weight.data.mul_(head_init_scale)
  635. module.bias.data.mul_(head_init_scale)
  636. def checkpoint_filter_fn(state_dict, model):
  637. """ Remap FB checkpoints -> timm """
  638. if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
  639. return state_dict # non-FB checkpoint
  640. if 'model' in state_dict:
  641. state_dict = state_dict['model']
  642. out_dict = {}
  643. import re
  644. for k, v in state_dict.items():
  645. k = k.replace('downsample_layers.0.', 'stem.')
  646. k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  647. k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
  648. k = k.replace('dwconv', 'conv_dw')
  649. k = k.replace('pwconv', 'mlp.fc')
  650. k = k.replace('head.', 'head.fc.')
  651. if k.startswith('norm.'):
  652. k = k.replace('norm', 'head.norm')
  653. if v.ndim == 2 and 'head' not in k:
  654. model_shape = model.state_dict()[k].shape
  655. v = v.reshape(model_shape)
  656. out_dict[k] = v
  657. return out_dict
  658. def _create_convnext(variant, pretrained=False, **kwargs):
  659. model = build_model_with_cfg(
  660. ConvNeXt, variant, pretrained,
  661. default_cfg=default_cfgs[variant],
  662. pretrained_filter_fn=checkpoint_filter_fn,
  663. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  664. **kwargs)
  665. return model
  666. # @register_model
  667. # def convnext_nano_hnf(pretrained=False, **kwargs):
  668. # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
  669. # model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
  670. # return model
  671. # @register_model
  672. # def convnext_tiny_hnf(pretrained=False, **kwargs):
  673. # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
  674. # model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
  675. # return model
  676. # @register_model
  677. # def convnext_tiny_hnfd(pretrained=False, **kwargs):
  678. # model_args = dict(
  679. # depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs)
  680. # model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
  681. # return model
  682. @register_model
  683. def convnext_micro(pretrained=False, **kwargs):
  684. model_args = dict(depths=(3, 3, 3, 3), dims=(64, 128, 256, 512), **kwargs)
  685. model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
  686. return model
  687. @register_model
  688. def convnext_tiny(pretrained=False, **kwargs):
  689. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
  690. model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
  691. return model
  692. @register_model
  693. def convnext_small(pretrained=False, **kwargs):
  694. model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
  695. model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
  696. return model
  697. @register_model
  698. def convnext_base(pretrained=False, **kwargs):
  699. model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  700. model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
  701. return model
  702. # @register_model
  703. # def convnext_large(pretrained=False, **kwargs):
  704. # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
  705. # model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
  706. # return model
  707. # @register_model
  708. # def convnext_tiny_in22ft1k(pretrained=False, **kwargs):
  709. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  710. # model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args)
  711. # return model
  712. # @register_model
  713. # def convnext_small_in22ft1k(pretrained=False, **kwargs):
  714. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  715. # model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args)
  716. # return model
  717. # @register_model
  718. # def convnext_base_in22ft1k(pretrained=False, **kwargs):
  719. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  720. # model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args)
  721. # return model
  722. # @register_model
  723. # def convnext_large_in22ft1k(pretrained=False, **kwargs):
  724. # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
  725. # model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args)
  726. # return model
  727. # @register_model
  728. # def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
  729. # model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
  730. # model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args)
  731. # return model
  732. # @register_model
  733. # def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs):
  734. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  735. # model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args)
  736. # return model
  737. # @register_model
  738. # def convnext_small_384_in22ft1k(pretrained=False, **kwargs):
  739. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  740. # model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args)
  741. # return model
  742. # @register_model
  743. # def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
  744. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  745. # model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args)
  746. # return model
  747. # @register_model
  748. # def convnext_large_384_in22ft1k(pretrained=False, **kwargs):
  749. # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
  750. # model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args)
  751. # return model
  752. # @register_model
  753. # def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
  754. # model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
  755. # model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args)
  756. # return model
  757. # @register_model
  758. # def convnext_tiny_in22k(pretrained=False, **kwargs):
  759. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  760. # model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args)
  761. # return model
  762. # @register_model
  763. # def convnext_small_in22k(pretrained=False, **kwargs):
  764. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  765. # model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args)
  766. # return model
  767. # @register_model
  768. # def convnext_base_in22k(pretrained=False, **kwargs):
  769. # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
  770. # model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args)
  771. # return model
  772. # @register_model
  773. # def convnext_large_in22k(pretrained=False, **kwargs):
  774. # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
  775. # model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args)
  776. # return model
  777. # @register_model
  778. # def convnext_xlarge_in22k(pretrained=False, **kwargs):
  779. # model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
  780. # model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
  781. # return model
  782. class Conv3d(nn.Conv3d):
  783. def __init__(self, in_ch, out_ch, kernel_size, stride, padding=0, groups=1, factor=False):
  784. super().__init__(in_ch, out_ch, kernel_size, stride=stride, padding=padding, groups=groups)
  785. self.factor = factor
  786. self.in_ch=in_ch
  787. self.out_ch=out_ch
  788. self.kernel_size=[kernel_size] if isinstance(kernel_size, int) else kernel_size
  789. self.stride=stride
  790. self.padding=padding
  791. self.groups=groups
  792. if self.factor:
  793. self.weight = nn.Parameter(self.weight[:, :, 0, :, :]) # Subsample time dimension
  794. self.time_weight = nn.Parameter(self.weight.new_ones(self.kernel_size[0]) / self.kernel_size[0])
  795. else:
  796. pass
  797. def forward(self, x):
  798. if self.factor:
  799. weight = self.weight[:, :, None, :, :] * self.time_weight[:, None, None]
  800. y = F.conv3d(x, weight, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups)
  801. else:
  802. y = super().forward(x)
  803. return y
  804. class Conv3dWrapper(nn.Module):
  805. """
  806. Light wrapper to make consistent with 2d version (allows for easier inflation).
  807. """
  808. def __init__(self, dim_in, dim_out, **kwargs):
  809. super().__init__()
  810. self.conv = Conv3d(dim_in, dim_out, **kwargs)
  811. def forward(self, x, resolution=None):
  812. return self.conv(x)
  813. class ConvNeXtBlock3D(nn.Module):
  814. """ ConvNeXt Block
  815. # previous convnext notes:
  816. There are two equivalent implementations:
  817. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  818. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  819. Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
  820. choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
  821. is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
  822. # two options for convs are:
  823. - conv2d, depthwise (original)
  824. - s4nd, used if a layer config passed
  825. Args:
  826. dim (int): Number of input channels.
  827. drop_path (float): Stochastic depth rate. Default: 0.0
  828. ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
  829. layer (config/dict): config for s4 layer
  830. """
  831. def __init__(self,
  832. dim,
  833. drop_path=0.,
  834. drop_mlp=0.,
  835. ls_init_value=1e-6,
  836. conv_mlp=False,
  837. mlp_ratio=4,
  838. norm_layer=None,
  839. block_tempor_kernel=3,
  840. layer=None,
  841. factor_3d=False,
  842. ):
  843. super().__init__()
  844. assert norm_layer is not None
  845. # if not norm_layer:
  846. # norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
  847. mlp_layer = ConvMlp if conv_mlp else Mlp
  848. self.use_conv_mlp = conv_mlp
  849. # Depthwise conv
  850. if layer is None:
  851. tempor_padding = block_tempor_kernel // 2 # or 2
  852. # self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
  853. self.conv_dw = Conv3dWrapper(
  854. dim,
  855. dim,
  856. kernel_size=(block_tempor_kernel, 7, 7),
  857. padding=(tempor_padding, 3, 3),
  858. stride=(1, 1, 1),
  859. groups=dim,
  860. factor=factor_3d,
  861. ) # depthwise conv
  862. else:
  863. self.conv_dw = utils.instantiate(registry.layer, layer, dim)
  864. self.norm = norm_layer(dim)
  865. self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU, drop=drop_mlp)
  866. self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
  867. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  868. def forward(self, x):
  869. shortcut = x
  870. x = self.conv_dw(x)
  871. if self.use_conv_mlp:
  872. x = self.norm(x)
  873. x = self.mlp(x)
  874. else:
  875. x = x.permute(0, 2, 3, 4, 1)
  876. x = self.norm(x)
  877. x = self.mlp(x)
  878. x = x.permute(0, 4, 1, 2, 3)
  879. if self.gamma is not None:
  880. x = x.mul(self.gamma.reshape(1, -1, 1, 1, 1))
  881. x = self.drop_path(x) + shortcut
  882. return x
  883. class ConvNeXtStage3D(nn.Module):
  884. """
  885. Will create a stage, made up of downsampling and conv blocks.
  886. There are 2 choices for each of these:
  887. downsampling: s4 or strided conv (original)
  888. conv stage: s4 or conv2d (original)
  889. """
  890. def __init__(
  891. self,
  892. in_chs,
  893. out_chs,
  894. video_size=None, # L, H, W
  895. stride=(2, 2, 2), # Strides for L, H, W
  896. depth=2,
  897. dp_rates=None,
  898. ls_init_value=1.0,
  899. conv_mlp=False,
  900. norm_layer=None,
  901. cl_norm_layer=None,
  902. stage_layer=None, # config
  903. block_tempor_kernel=3,
  904. downsample_type=None,
  905. downsample_act=False,
  906. downsample_glu=False,
  907. factor_3d=False,
  908. drop_mlp=0.,
  909. ):
  910. super().__init__()
  911. self.grad_checkpointing = False
  912. # 2 options to downsample
  913. if in_chs != out_chs or np.any(np.array(stride) > 1):
  914. # s4 type copies config from corresponding stage layer
  915. if downsample_type == 's4nd':
  916. print("s4nd downsample")
  917. downsample_layer_copy = copy.deepcopy(stage_layer)
  918. downsample_layer_copy["l_max"] = video_size # always need to update curr l_max
  919. self.norm_layer = norm_layer(in_chs)
  920. # mimics strided conv but w/s4
  921. self.downsample = S4DownSample(
  922. downsample_layer_copy,
  923. in_chs,
  924. out_chs,
  925. stride=stride,
  926. activate=downsample_act,
  927. glu=downsample_glu,
  928. pool3d=True,
  929. )
  930. # self.downsample = nn.Sequential(
  931. # norm_layer(in_chs),
  932. # S4DownSample(
  933. # downsample_layer_copy,
  934. # in_chs,
  935. # out_chs,
  936. # stride=stride,
  937. # activate=downsample_act,
  938. # glu=downsample_glu,
  939. # pool3d=True,
  940. # )
  941. # )
  942. # strided conv
  943. else:
  944. print("strided conv downsample")
  945. self.norm_layer = norm_layer(in_chs)
  946. self.downsample = Conv3dWrapper(in_chs, out_chs, kernel_size=stride, stride=stride, factor=factor_3d)
  947. # self.downsample = nn.Sequential(
  948. # norm_layer(in_chs),
  949. # Conv3d(in_chs, out_chs, kernel_size=stride, stride=stride, factor=factor_3d),
  950. # )
  951. else:
  952. self.norm_layer = nn.Identity()
  953. self.downsample = nn.Identity()
  954. if stage_layer is not None:
  955. stage_layer["l_max"] = [
  956. x // stride if isinstance(stride, int) else x // stride[i]
  957. for i, x in enumerate(video_size)
  958. ]
  959. dp_rates = dp_rates or [0.] * depth
  960. self.blocks = nn.Sequential(*[
  961. ConvNeXtBlock3D(
  962. dim=out_chs,
  963. drop_path=dp_rates[j],
  964. drop_mlp=drop_mlp,
  965. ls_init_value=ls_init_value,
  966. conv_mlp=conv_mlp,
  967. norm_layer=norm_layer if conv_mlp else cl_norm_layer,
  968. block_tempor_kernel=block_tempor_kernel,
  969. layer=stage_layer,
  970. factor_3d=factor_3d,
  971. )
  972. for j in range(depth)
  973. ])
  974. def forward(self, x):
  975. x = self.norm_layer(x)
  976. x = self.downsample(x)
  977. x = self.blocks(x)
  978. return x
  979. class Stem3d(nn.Module):
  980. def __init__(self,
  981. stem_type='patch', # supports `s4nd` + avg pool
  982. in_chans=3,
  983. spatial_patch_size=4,
  984. tempor_patch_size=4,
  985. stem_channels=8,
  986. dims=(96, 192, 384, 768),
  987. stem_l_max=None, # len of l_max in stem (if using s4)
  988. norm_layer=None,
  989. custom_ln=False,
  990. layer=None, # Shared config dictionary for the core layer
  991. stem_layer=None,
  992. factor_3d=False,
  993. ):
  994. super().__init__()
  995. self.stem_type = stem_type
  996. # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
  997. if stem_type == 'patch':
  998. print("stem type: ", 'patch')
  999. kernel_3d = [tempor_patch_size, spatial_patch_size, spatial_patch_size]
  1000. self.stem = nn.Sequential(
  1001. Conv3d(
  1002. in_chans,
  1003. dims[0],
  1004. kernel_size=kernel_3d,
  1005. stride=kernel_3d,
  1006. factor=factor_3d,
  1007. ),
  1008. norm_layer(dims[0]),
  1009. )
  1010. elif stem_type == 'new_s4nd_patch':
  1011. print("stem type: ", 'new_s4nd_patch')
  1012. stem_layer_copy = copy.deepcopy(stem_layer)
  1013. assert stem_l_max is not None, "need to provide a stem_l_max to use stem=new_s4nd_patch"
  1014. stem_layer_copy["l_max"] = stem_l_max
  1015. s4_ds = utils.instantiate(registry.layer, stem_layer_copy, in_chans, out_channels=stem_channels)
  1016. kernel_3d = [tempor_patch_size, spatial_patch_size, spatial_patch_size]
  1017. self.stem = nn.Sequential(
  1018. s4_ds,
  1019. nn.AvgPool3d(kernel_size=kernel_3d, stride=kernel_3d),
  1020. TransposedLinear(stem_channels, 2*dims[0]),
  1021. nn.GLU(dim=1),
  1022. norm_layer(dims[0]),
  1023. )
  1024. else:
  1025. raise NotImplementedError("provide a valid stem type!")
  1026. def forward(self, x, resolution=None):
  1027. # if using s4nd layer, need to pass resolution
  1028. if self.stem_type in ['new_s4nd_patch']:
  1029. x = self.stem(x, resolution)
  1030. else:
  1031. x = self.stem(x)
  1032. return x
  1033. class ConvNeXt3D(nn.Module):
  1034. r""" ConvNeXt
  1035. A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
  1036. Args:
  1037. in_chans (int): Number of input image channels. Default: 3
  1038. num_classes (int): Number of classes for classification head. Default: 1000
  1039. depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
  1040. dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
  1041. drop_head (float): Head dropout rate
  1042. drop_path_rate (float): Stochastic depth rate. Default: 0.
  1043. ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
  1044. head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
  1045. """
  1046. def __init__(
  1047. self,
  1048. in_chans=3,
  1049. num_classes=1000,
  1050. # global_pool='avg',
  1051. spatial_patch_size=4,
  1052. tempor_patch_size=4,
  1053. output_spatial_stride=32,
  1054. # patch_size=(1, 4, 4),
  1055. stem_channels=8,
  1056. depths=(3, 3, 9, 3),
  1057. dims=(96, 192, 384, 768),
  1058. ls_init_value=1e-6,
  1059. conv_mlp=False, # whether to transpose channels to last dim inside MLP
  1060. stem_type='patch', # supports `s4nd` + avg pool
  1061. stem_l_max=None, # len of l_max in stem (if using s4)
  1062. downsample_type='patch', # supports `s4nd` + avg pool
  1063. downsample_act=False,
  1064. downsample_glu=False,
  1065. head_init_scale=1.,
  1066. head_norm_first=False,
  1067. norm_layer=None,
  1068. custom_ln=False,
  1069. drop_head=0.,
  1070. drop_path_rate=0.,
  1071. drop_mlp=0.,
  1072. layer=None, # Shared config dictionary for the core layer
  1073. stem_layer=None,
  1074. stage_layers=None,
  1075. video_size=None,
  1076. block_tempor_kernel=3, # only for non-s4 block
  1077. temporal_stage_strides=None,
  1078. factor_3d=False,
  1079. **kwargs, # catch all
  1080. ):
  1081. super().__init__()
  1082. assert output_spatial_stride == 32
  1083. if norm_layer is None:
  1084. if custom_ln:
  1085. norm_layer = TransposedLN
  1086. else:
  1087. norm_layer = partial(LayerNorm3d, eps=1e-6)
  1088. cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
  1089. else:
  1090. assert conv_mlp,\
  1091. 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
  1092. cl_norm_layer = norm_layer
  1093. self.num_classes = num_classes
  1094. self.drop_head = drop_head
  1095. self.feature_info = []
  1096. # Broadcast dictionaries
  1097. if layer is not None:
  1098. stage_layers = [OmegaConf.merge(layer, s) for s in stage_layers]
  1099. stem_layer = OmegaConf.merge(layer, stem_layer)
  1100. # instantiate stem here
  1101. self.stem = Stem3d(
  1102. stem_type=stem_type, # supports `s4nd` + avg pool
  1103. in_chans=in_chans,
  1104. spatial_patch_size=spatial_patch_size,
  1105. tempor_patch_size=tempor_patch_size,
  1106. stem_channels=stem_channels,
  1107. dims=dims,
  1108. stem_l_max=stem_l_max, # len of l_max in stem (if using s4)
  1109. norm_layer=norm_layer,
  1110. custom_ln=custom_ln,
  1111. layer=layer, # Shared config dictionary for the core layer
  1112. stem_layer=stem_layer,
  1113. factor_3d=factor_3d,
  1114. )
  1115. stem_stride = [tempor_patch_size, spatial_patch_size, spatial_patch_size]
  1116. prev_chs = dims[0]
  1117. # TODO: something else here?
  1118. curr_video_size = [
  1119. x // stem_stride if isinstance(stem_stride, int) else x // stem_stride[i]
  1120. for i, x in enumerate(video_size)
  1121. ]
  1122. self.stages = nn.Sequential()
  1123. dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
  1124. stages = []
  1125. # 4 feature resolution stages, each consisting of multiple residual blocks
  1126. for i in range(4):
  1127. # if stem downsampled by 4, then in stage 0, we don't downsample
  1128. # if stem downsampled by 2, then in stage 0, we downsample by 2
  1129. # all other stages we downsample by 2 no matter what
  1130. # might want to alter the
  1131. # temporal stride, we parse this specially
  1132. tempor_stride = temporal_stage_strides[i] if temporal_stage_strides is not None else 2
  1133. stride = [1, 1, 1] if i == 0 and np.any(np.array(stem_stride) >= 2) else [tempor_stride, 2, 2] # stride 1 is no downsample (because already ds in stem)
  1134. # print("stage {}, before downsampled img size {}, stride {}".format(i, curr_img_size, stride))
  1135. out_chs = dims[i]
  1136. stages.append(
  1137. ConvNeXtStage3D(
  1138. prev_chs,
  1139. out_chs,
  1140. video_size=curr_video_size,
  1141. stride=stride,
  1142. depth=depths[i],
  1143. dp_rates=dp_rates[i],
  1144. ls_init_value=ls_init_value,
  1145. conv_mlp=conv_mlp,
  1146. norm_layer=norm_layer,
  1147. cl_norm_layer=cl_norm_layer,
  1148. stage_layer=stage_layers[i],
  1149. block_tempor_kernel=block_tempor_kernel,
  1150. downsample_type=downsample_type,
  1151. downsample_act=downsample_act,
  1152. downsample_glu=downsample_glu,
  1153. factor_3d=factor_3d,
  1154. drop_mlp=drop_mlp,
  1155. )
  1156. )
  1157. prev_chs = out_chs
  1158. # update image size for next stage
  1159. curr_video_size = [
  1160. x // stride if isinstance(stride, int) else x // stride[i]
  1161. for i, x in enumerate(curr_video_size)
  1162. ]
  1163. # # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
  1164. # self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
  1165. self.stages = nn.Sequential(*stages)
  1166. self.num_features = prev_chs
  1167. # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
  1168. # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
  1169. self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
  1170. self.head = nn.Sequential(OrderedDict([
  1171. ('global_pool', nn.AdaptiveAvgPool3d(1)),
  1172. ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
  1173. ('flatten', nn.Flatten(1)),
  1174. ('drop', nn.Dropout(self.drop_head)),
  1175. ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
  1176. named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
  1177. @torch.jit.ignore
  1178. def group_matcher(self, coarse=False):
  1179. return dict(
  1180. stem=r'^stem',
  1181. blocks=r'^stages\.(\d+)' if coarse else [
  1182. (r'^stages\.(\d+)\.downsample', (0,)), # blocks
  1183. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  1184. (r'^norm_pre', (99999,))
  1185. ]
  1186. )
  1187. @torch.jit.ignore
  1188. def set_grad_checkpointing(self, enable=True):
  1189. for s in self.stages:
  1190. s.grad_checkpointing = enable
  1191. @torch.jit.ignore
  1192. def get_classifier(self):
  1193. return self.head.fc
  1194. def reset_classifier(self, num_classes=0, **kwargs):
  1195. if global_pool is not None:
  1196. self.head.global_pool = nn.AdaptiveAvgPool
  1197. self.head.flatten = nn.Flatten(1)
  1198. self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  1199. def forward_features(self, x):
  1200. x = self.stem(x)
  1201. x = self.stages(x)
  1202. x = self.norm_pre(x)
  1203. return x
  1204. def forward_head(self, x, pre_logits: bool = False):
  1205. # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
  1206. x = self.head.global_pool(x)
  1207. x = self.head.norm(x)
  1208. x = self.head.flatten(x)
  1209. x = self.head.drop(x)
  1210. return x if pre_logits else self.head.fc(x)
  1211. def forward(self, x, state=None):
  1212. x = self.forward_features(x)
  1213. x = self.forward_head(x)
  1214. return x, None
  1215. def _create_convnext3d(variant, pretrained=False, **kwargs):
  1216. model = build_model_with_cfg(
  1217. ConvNeXt3D,
  1218. variant,
  1219. pretrained,
  1220. default_cfg=default_cfgs[variant],
  1221. pretrained_filter_fn=checkpoint_filter_fn,
  1222. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  1223. **kwargs,
  1224. )
  1225. return model
  1226. @register_model
  1227. def convnext3d_tiny(pretrained=False, **kwargs):
  1228. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
  1229. model = _create_convnext3d('convnext_tiny', pretrained=pretrained, **model_args)
  1230. return model
  1231. def convnext_timm_tiny_2d_to_3d(model, state_dict, ignore_head=True, normalize=True):
  1232. """
  1233. inputs:
  1234. model: nn.Module, the from 'scratch' model
  1235. state_dict: dict, from the pretrained weights
  1236. ignore_head: bool, whether to inflate weights in the head (or keep scratch weights).
  1237. If number of classes changes (eg, imagenet to hmdb51), then you need to use this.
  1238. normalize: bool, if set to True (default), it inflates with a factor of 1, and if
  1239. set to False it inflates with a factor of 1/T where T is the temporal length for that kernel
  1240. return:
  1241. state_dict: dict, update with inflated weights
  1242. """
  1243. model_scratch_params_dict = dict(model.named_parameters())
  1244. prefix = list(state_dict.keys())[0].split('.')[0] # grab prefix in the keys for state_dict params
  1245. old_state_dict = copy.deepcopy(state_dict)
  1246. # loop through keys (in either)
  1247. # only check `weights`
  1248. # compare shapes btw 3d model and 2d model
  1249. # if, different, then broadcast
  1250. # then set the broadcasted version into the model value
  1251. for key in sorted(model_scratch_params_dict.keys()):
  1252. scratch_params = model_scratch_params_dict[key]
  1253. # need to add the predix 'model' in convnext
  1254. key_with_prefix = prefix + '.' + key
  1255. # make sure key is in the loaded params first, if not, then print it out
  1256. loaded_params = state_dict.get(key_with_prefix, None)
  1257. if 'time_weight' in key:
  1258. print("found time_weight parameter, train from scratch", key)
  1259. used_params = scratch_params
  1260. elif loaded_params is None:
  1261. # This should never happen for 2D -> 3D ConvNext
  1262. print("Missing key in pretrained model!", key_with_prefix)
  1263. raise Exception
  1264. # used_params = scratch_params
  1265. elif ignore_head and 'head' in key:
  1266. # ignore head weights
  1267. print("found head key / parameter, ignore", key)
  1268. used_params = scratch_params
  1269. elif len(scratch_params.shape) != len(loaded_params.shape):
  1270. # same keys, but inflating weights
  1271. print('key: shape DOES NOT MATCH', key)
  1272. print("scratch:", scratch_params.shape)
  1273. print("pretrain:", loaded_params.shape)
  1274. # need the index [-3], 3rd from last, the temporal dim
  1275. index = -3
  1276. temporal_dim = scratch_params.shape[index] # temporal len of kernel
  1277. temporal_kernel_factor = 1 if normalize else 1 / temporal_dim
  1278. used_params = repeat(temporal_kernel_factor*loaded_params, '... h w -> ... t h w', t=temporal_dim)
  1279. # loaded_params = temporal_kernel_factor * loaded_params.unsqueeze(index) # unsqueeze
  1280. # used_params = torch.cat(temporal_dim * [loaded_params], axis=index) # stack at this dim
  1281. else:
  1282. # print('key: shape MATCH', key) # loading matched weights
  1283. # used_params = loaded_params
  1284. continue
  1285. state_dict[key_with_prefix] = used_params
  1286. return state_dict
  1287. def convnext_timm_tiny_s4nd_2d_to_3d(model, state_dict, ignore_head=True, jank=False):
  1288. """
  1289. inputs:
  1290. model: nn.Module, the from 'scratch' model
  1291. state_dict: dict, from the pretrained weights
  1292. ignore_head: bool, whether to inflate weights in the head (or keep scratch weights).
  1293. If number of classes changes (eg, imagenet to hmdb51), then you need to use this.
  1294. return:
  1295. state_dict: dict, update with inflated weights
  1296. """
  1297. # model_scratch_params_dict = dict(model.named_parameters())
  1298. model_scratch_params_dict = {**dict(model.named_parameters()), **dict(model.named_buffers())}
  1299. prefix = list(state_dict.keys())[0].split('.')[0] # grab prefix in the keys for state_dict params
  1300. new_state_dict = copy.deepcopy(state_dict)
  1301. # for key in state_dict.keys():
  1302. # print(key)
  1303. # breakpoint()
  1304. for key in sorted(model_scratch_params_dict.keys()):
  1305. # need to add the predix 'model' in convnext
  1306. key_with_prefix = prefix + '.' + key
  1307. # HACK
  1308. old_key_with_prefix = key_with_prefix.replace("inv_w_real", "log_w_real")
  1309. # print(key)
  1310. # if '.kernel.L' in key:
  1311. # print(key, state_dict[old_key_with_prefix])
  1312. if '.kernel.0' in key:
  1313. # temporal dim is loaded from scratch
  1314. print("found .kernel.0:", key)
  1315. new_state_dict[key_with_prefix] = model_scratch_params_dict[key]
  1316. elif '.kernel.1' in key:
  1317. # This is the 1st kernel --> 0th kernel from pretrained model
  1318. print("FOUND .kernel.1, putting kernel 0 into kernel 1", key)
  1319. new_state_dict[key_with_prefix] = state_dict[old_key_with_prefix.replace(".kernel.1", ".kernel.0")]
  1320. elif '.kernel.2' in key:
  1321. print("FOUND .kernel.2, putting kernel 1 into kernel 2", key)
  1322. new_state_dict[key_with_prefix] = state_dict[old_key_with_prefix.replace(".kernel.2", ".kernel.1")]
  1323. elif ignore_head and 'head' in key:
  1324. # ignore head weights
  1325. print("found head key / parameter, ignore", key)
  1326. new_state_dict[key_with_prefix] = model_scratch_params_dict[key]
  1327. # keys match
  1328. else:
  1329. # check if mismatched shape, if so, need to inflate
  1330. # this covers cases where we did not use s4 (eg, optionally use conv2d in downsample or the stem)
  1331. try:
  1332. if model_scratch_params_dict[key].ndim != state_dict[old_key_with_prefix].ndim:
  1333. print("matching keys, but shapes mismatched! Need to inflate!", key)
  1334. # need the index [-3], 3rd from last, the temporal dim
  1335. index = -3
  1336. dim_len = model_scratch_params_dict[key].shape[index]
  1337. # loaded_params = state_dict[key_with_prefix].unsqueeze(index) # unsqueeze
  1338. # new_state_dict[key_with_prefix] = torch.cat(dim_len * [loaded_params], axis=index) # stack at this dim
  1339. new_state_dict[key_with_prefix] = repeat(state_dict[old_key_with_prefix], '... h w -> ... t h w', t=dim_len) # torch.cat(dim_len * [loaded_params], axis=index) # stack at this dim
  1340. else:
  1341. # matching case, shapes, match, load into new_state_dict as is
  1342. new_state_dict[key_with_prefix] = state_dict[old_key_with_prefix]
  1343. # something went wrong, the keys don't actually match (and they should)!
  1344. except:
  1345. print("unmatched key", key)
  1346. breakpoint()
  1347. # continue
  1348. return new_state_dict
  1349. def main():
  1350. model = convnext_tiny(
  1351. stem_type='new_s4nd_patch',
  1352. stem_channels=32,
  1353. stem_l_max=[16, 16],
  1354. downsample_type='s4nd',
  1355. downsample_glu=True,
  1356. stage_layers=[dict(dt_min=0.1, dt_max=1.0)] * 4,
  1357. stem_layer=dict(dt_min=0.1, dt_max=1.0, init='fourier'),
  1358. layer=dict(
  1359. _name_='s4nd',
  1360. bidirectional=True,
  1361. init='fourier',
  1362. dt_min=0.01,
  1363. dt_max=1.0,
  1364. n_ssm=1,
  1365. return_state=False,
  1366. ),
  1367. img_size=[224, 224],
  1368. )
  1369. # model = convnext_tiny(
  1370. # stem_type='patch',
  1371. # downsample_type=None,
  1372. # stage_layers=[None] * 4,
  1373. # img_size=[224, 224],
  1374. # )
  1375. vmodel = convnext3d_tiny(
  1376. stem_type='new_s4nd_patch',
  1377. stem_channels=32,
  1378. stem_l_max=[100, 16, 16],
  1379. downsample_type='s4nd',
  1380. downsample_glu=True,
  1381. stage_layers=[dict(dt_min=0.1, dt_max=1.0)] * 4,
  1382. stem_layer=dict(dt_min=0.1, dt_max=1.0, init='fourier'),
  1383. layer=dict(
  1384. _name_='s4nd',
  1385. bidirectional=True,
  1386. init='fourier',
  1387. dt_min=0.01,
  1388. dt_max=1.0,
  1389. n_ssm=1,
  1390. contract_version=1,
  1391. return_state=False,
  1392. ),
  1393. video_size=[100, 224, 224],
  1394. )
  1395. # vmodel = convnext3d_tiny(
  1396. # stem_type='patch',
  1397. # downsample_type=None,
  1398. # stage_layers=[None] * 4,
  1399. # video_size=[100, 224, 224],
  1400. # )
  1401. model.cuda()
  1402. x = torch.rand(1, 3, 224, 224).cuda()
  1403. y = model(x)[0]
  1404. print(y)
  1405. breakpoint()
  1406. vmodel.cuda()
  1407. x = torch.rand(1, 3, 50, 224, 224).cuda()
  1408. y = vmodel(x)[0]
  1409. print(y)
  1410. print(y.shape)
  1411. breakpoint()
# 3D stem conv options:
#   - kernel (1, 4, 4) with stride (1, 4, 4)
#   - kernel (7, 4, 4) with stride (2, 4, 4)
  1415. # =======================================
  1416. # s4/configs/experiment/s4nd/convnext/convnext_timm_tiny_s4nd_imagenet.yaml
  1417. """
  1418. model:
  1419. img_size: ${dataset.__l_max}
  1420. drop_path_rate: 0.1
  1421. patch_size: 4 # 2 or 4, use for stem downsample factor
  1422. stem_channels: 32 # only used for s4nd stem currently
  1423. stem_type: new_s4nd_patch # options: patch (regular convnext), s4nd_patch, new_s4nd_patch (best), s4nd
  1424. stem_l_max: [16, 16] # stem_l_max=None, # len of l_max in stem (if using s4)
  1425. downsample_type: s4nd # eg, s4nd, null (for regular strided conv)
  1426. downsample_act: false
  1427. downsample_glu: True
  1428. conv_mlp: false
  1429. custom_ln: false # only used if conv_mlp=1, should benchmark to make sure this is faster/more mem efficient, also need to turn off weight decay
  1430. layer: # null means use regular conv2d in convnext
  1431. _name_: s4nd
  1432. d_state: 64
  1433. channels: 1
  1434. bidirectional: true
  1435. activation: null # mimics convnext style
  1436. final_act: none
  1437. initializer: null
  1438. weight_norm: false
  1439. dropout: 0
  1440. tie_dropout: ${oc.select:model.tie_dropout,null}
  1441. init: fourier
  1442. rank: 1
  1443. trank: 1
  1444. dt_min: 0.01
  1445. dt_max: 1.0
  1446. lr: 0.001
  1447. # length_correction: true
  1448. n_ssm: 1
  1449. deterministic: false # Special C init
  1450. l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to null and kernel will automatically resize
  1451. verbose: true
  1452. linear: true
  1453. return_state: false
  1454. bandlimit: null
  1455. contract_version: 0 # 0 is for 2d, 1 for 1d or 3d (or other)
  1456. stem_layer:
  1457. dt_min: 0.1
  1458. dt_max: 1.0
  1459. init: fourier
  1460. stage_layers:
  1461. - dt_min: 0.1
  1462. dt_max: 1.0
  1463. - dt_min: 0.1
  1464. dt_max: 1.0
  1465. - dt_min: 0.1
  1466. dt_max: 1.0
  1467. - dt_min: 0.1
  1468. dt_max: 1.0
  1469. """
  1470. # s4/configs/model/layer/s4nd.yaml
  1471. """
  1472. _name_: s4nd
  1473. d_state: 64
  1474. channels: 1
  1475. bidirectional: true
  1476. activation: gelu
  1477. final_act: glu
  1478. initializer: null
  1479. weight_norm: false
  1480. trank: 1
  1481. dropout: ${..dropout} # Same as null
  1482. tie_dropout: ${oc.select:model.tie_dropout,null}
  1483. init: legs
  1484. rank: 1
  1485. dt_min: 0.001
  1486. dt_max: 0.1
  1487. lr:
  1488. dt: 0.001
  1489. A: 0.001
  1490. B: 0.001
  1491. n_ssm: 1
  1492. deterministic: false # Special C init
  1493. l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to null and kernel will automatically resize
  1494. verbose: true
  1495. linear: false
  1496. """
  1497. def convnext_tiny_s4nd():
  1498. model = ConvNeXt(
  1499. depths=(3, 3, 9, 3),
  1500. dims=(96, 192, 384, 768),
  1501. img_size=(224, 224),
  1502. patch_size=4,
  1503. stem_channels=32,
  1504. stem_type="new_s4nd_patch",
  1505. stem_l_max=[16, 16],
  1506. downsample_act=False,
  1507. downsample_glu=True,
  1508. conv_mlp=False,
  1509. custom_ln=False,
  1510. layer=dict(
  1511. _name_="s4nd",
  1512. d_state=64,
  1513. channels=1,
  1514. bidirectional=True,
  1515. # activation="null",
  1516. # final_act="none",
  1517. final_act=None,
  1518. # initializer="null",
  1519. weight_norm=False,
  1520. dropout=0,
  1521. # tie_dropout="null",
  1522. init="fourier",
  1523. rank=1,
  1524. trank=1,
  1525. dt_min=0.01,
  1526. dt_max=1.0,
  1527. lr=0.001,
  1528. n_ssm=1,
  1529. deterministic=False,
  1530. # l_max="null",
  1531. verbose=True,
  1532. linear=True,
  1533. return_state=False,
  1534. # bandlimit="null",
  1535. contract_version=0,
  1536. ),
  1537. stem_layer=dict(
  1538. dt_min=0.1,
  1539. dt_max=1.0,
  1540. init="fourier",
  1541. ),
  1542. stage_layers=[
  1543. dict(dt_min=0.1, dt_max=1.0),
  1544. dict(dt_min=0.1, dt_max=1.0),
  1545. dict(dt_min=0.1, dt_max=1.0),
  1546. dict(dt_min=0.1, dt_max=1.0),
  1547. ],
  1548. )
  1549. def forward(self, x, resolution=1, state=None):
  1550. x = self.forward_features(x, resolution)
  1551. x = self.forward_head(x)
  1552. return x
  1553. model.forward = partial(forward, model)
  1554. return model
  1555. if __name__ == '__main__':
  1556. convnext_tiny_s4nd()