# vit_all.py
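"""The original Vision Transformer (ViT) from timm, adapted so that each block is a
configurable ``SequenceResidualBlock`` (e.g. an S4ND layer) instead of attention + MLP.

Copyright 2020 Ross Wightman.
"""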
  4. import math
  5. import logging
  6. from functools import partial
  7. from collections import OrderedDict
  8. from copy import deepcopy
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
  13. from timm.models.layers import PatchEmbed, Mlp, trunc_normal_, lecun_normal_
  14. from src.models.sequence.base import SequenceModule
  15. from src.models.nn import Normalization
  16. from src.models.sequence.backbones.block import SequenceResidualBlock
  17. from src.utils.config import to_list, to_dict
# Module-level logger; used by resize_pos_embed and _create_vision_transformer.
_logger = logging.getLogger(__name__)
  19. def _cfg(url='', **kwargs):
  20. return {
  21. 'url': url,
  22. 'num_classes': 1000,
  23. 'input_size': (3, 224, 224),
  24. 'pool_size': None,
  25. # 'crop_pct': .9,
  26. # 'interpolation': 'bicubic',
  27. # 'fixed_input_size': True,
  28. # 'mean': IMAGENET_DEFAULT_MEAN,
  29. # 'std': IMAGENET_DEFAULT_STD,
  30. # 'first_conv': 'patch_embed.proj',
  31. 'classifier': 'head',
  32. **kwargs,
  33. }
  34. default_cfgs = {
  35. # patch models (my experiments)
  36. 'vit_small_patch16_224': _cfg(
  37. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
  38. ),
  39. # patch models (weights ported from official Google JAX impl)
  40. 'vit_base_patch16_224': _cfg(
  41. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
  42. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  43. ),
  44. }
# Legacy timm attention block, kept for reference only. It is commented out, and
# the model below builds its blocks from SequenceResidualBlock instead.
#
# class Block(nn.Module):
#     def __init__(
#         self,
#         dim,
#         num_heads,
#         mlp_ratio=4.,
#         qkv_bias=False,
#         qk_scale=None,
#         drop=0.,
#         attn_drop=0.,
#         drop_path=0.,
#         act_layer=nn.GELU,
#         norm_layer=nn.LayerNorm,
#         attnlinear_cfg=None,
#         mlp_cfg=None
#     ):
#         super().__init__()
#         self.norm1 = norm_layer(dim)
#         self.attn = AttentionSimple(
#             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
#             linear_cfg=attnlinear_cfg)
#         self.drop_path = StochasticDepth(drop_path, mode='row')
#         self.norm2 = norm_layer(dim)
#         mlp_hidden_dim = int(dim * mlp_ratio)
#         if mlp_cfg is None:
#             self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
#         else:
#             self.mlp = hydra.utils.instantiate(mlp_cfg, in_features=dim, hidden_features=mlp_hidden_dim,
#                                                act_layer=act_layer, drop=drop, _recursive_=False)
#
#     def forward(self, x):
#         x = x + self.drop_path(self.attn(self.norm1(x)))
#         x = x + self.drop_path(self.mlp(self.norm2(x)))
#         return x
  78. class VisionTransformer(SequenceModule):
  79. """ Vision Transformer
  80. A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
  81. - https://arxiv.org/abs/2010.11929
  82. Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
  83. - https://arxiv.org/abs/2012.12877
  84. """
  85. def __init__(
  86. self,
  87. img_size=224,
  88. patch_size=16,
  89. in_chans=3,
  90. num_classes=1000,
  91. d_model=768,
  92. depth=12,
  93. # num_heads=12,
  94. expand=4,
  95. # qkv_bias=True,
  96. # qk_scale=None,
  97. representation_size=None,
  98. distilled=False,
  99. dropout=0.,
  100. # attn_drop_rate=0.,
  101. drop_path_rate=0.,
  102. embed_layer=PatchEmbed,
  103. norm='layer',
  104. # norm_layer=None,
  105. # act_layer=None,
  106. weight_init='',
  107. # attnlinear_cfg=None,
  108. # mlp_cfg=None,
  109. layer=None,
  110. # ff_cfg=None,
  111. transposed=False,
  112. layer_reps=1,
  113. use_pos_embed=False,
  114. use_cls_token=False,
  115. track_norms=False,
  116. ):
  117. """
  118. Args:
  119. img_size (int, tuple): input image size
  120. patch_size (int, tuple): patch size
  121. in_chans (int): number of input channels
  122. num_classes (int): number of classes for classification head
  123. d_model (int): embedding dimension
  124. depth (int): depth of transformer
  125. num_heads (int): number of attention heads
  126. mlp_ratio (int): ratio of mlp hidden dim to embedding dim
  127. qkv_bias (bool): enable bias for qkv if True
  128. qk_scale (float): override default qk scale of head_dim ** -0.5 if set
  129. representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
  130. distilled (bool): model includes a distillation token and head as in DeiT models
  131. dropout (float): dropout rate
  132. attn_drop_rate (float): attention dropout rate
  133. drop_path_rate (float): stochastic depth rate
  134. embed_layer (nn.Module): patch embedding layer
  135. norm_layer: (nn.Module): normalization layer
  136. weight_init: (str): weight init scheme
  137. """
  138. super().__init__()
  139. self.num_classes = num_classes
  140. self.num_features = self.d_model = d_model # num_features for consistency with other models
  141. self.num_tokens = 2 if distilled else 1
  142. self.use_pos_embed = use_pos_embed
  143. self.use_cls_token = use_cls_token
  144. # norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
  145. # act_layer = act_layer or nn.GELU
  146. self.track_norms = track_norms
  147. self.patch_embed = embed_layer(
  148. img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=d_model,
  149. )
  150. num_patches = self.patch_embed.num_patches
  151. self.cls_token = None
  152. self.dist_token = None
  153. if use_cls_token:
  154. self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
  155. self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) if distilled else None
  156. else:
  157. assert not distilled, 'Distillation token not supported without class token'
  158. self.pos_embed = None
  159. if use_pos_embed:
  160. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, d_model))
  161. self.pos_drop = nn.Dropout(p=dropout)
  162. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
  163. # self.blocks = nn.Sequential(*[
  164. # Block(
  165. # dim=d_model, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
  166. # drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
  167. # attnlinear_cfg=attnlinear_cfg, mlp_cfg=mlp_cfg)
  168. # for i in range(depth)
  169. # ])
  170. self.transposed = transposed
  171. layer = to_list(layer, recursive=False) * layer_reps
  172. # Some special arguments are passed into each layer
  173. for _layer in layer:
  174. # If layers don't specify dropout, add it
  175. if _layer.get('dropout', None) is None:
  176. _layer['dropout'] = dropout
  177. # Ensure all layers are shaped the same way
  178. _layer['transposed'] = transposed
  179. # # Layer arguments
  180. # layer_cfg = layer.copy()
  181. # layer_cfg['dropout'] = dropout
  182. # layer_cfg['transposed'] = self.transposed
  183. # layer_cfg['initializer'] = None
  184. # # layer_cfg['l_max'] = L
  185. # print("layer config", layer_cfg)
  186. # Config for the inverted bottleneck
  187. ff_cfg = {
  188. '_name_': 'ffn',
  189. 'expand': int(expand),
  190. 'transposed': self.transposed,
  191. 'activation': 'gelu',
  192. 'initializer': None,
  193. 'dropout': dropout,
  194. }
  195. blocks = []
  196. for i in range(depth):
  197. for _layer in layer:
  198. blocks.append(
  199. SequenceResidualBlock(
  200. d_input=d_model,
  201. i_layer=i,
  202. prenorm=True,
  203. dropout=dropout,
  204. layer=_layer,
  205. residual='R',
  206. norm=norm,
  207. pool=None,
  208. drop_path=dpr[i],
  209. )
  210. )
  211. if expand > 0:
  212. blocks.append(
  213. SequenceResidualBlock(
  214. d_input=d_model,
  215. i_layer=i,
  216. prenorm=True,
  217. dropout=dropout,
  218. layer=ff_cfg,
  219. residual='R',
  220. norm=norm,
  221. pool=None,
  222. drop_path=dpr[i],
  223. )
  224. )
  225. self.blocks = nn.Sequential(*blocks)
  226. # self.norm = norm_layer(d_model)
  227. if norm is None:
  228. self.norm = None
  229. elif isinstance(norm, str):
  230. self.norm = Normalization(d_model, transposed=self.transposed, _name_=norm)
  231. else:
  232. self.norm = Normalization(d_model, transposed=self.transposed, **norm)
  233. # Representation layer: generally defaults to nn.Identity()
  234. if representation_size and not distilled:
  235. self.num_features = representation_size
  236. self.pre_logits = nn.Sequential(OrderedDict([
  237. ('fc', nn.Linear(d_model, representation_size)),
  238. ('act', nn.Tanh())
  239. ]))
  240. else:
  241. self.pre_logits = nn.Identity()
  242. # Classifier head(s): TODO: move to decoder
  243. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  244. self.head_dist = None
  245. if distilled:
  246. self.head_dist = nn.Linear(self.d_model, self.num_classes) if num_classes > 0 else nn.Identity()
  247. # Weight init
  248. assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
  249. head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
  250. if self.pos_embed is not None:
  251. trunc_normal_(self.pos_embed, std=.02)
  252. if self.dist_token is not None:
  253. trunc_normal_(self.dist_token, std=.02)
  254. if weight_init.startswith('jax'):
  255. # leave cls token as zeros to match jax impl
  256. for n, m in self.named_modules():
  257. _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
  258. else:
  259. if self.cls_token is not None:
  260. trunc_normal_(self.cls_token, std=.02)
  261. self.apply(_init_vit_weights)
  262. def _init_weights(self, m):
  263. # this fn left here for compat with downstream users
  264. _init_vit_weights(m)
  265. @torch.jit.ignore
  266. def no_weight_decay(self):
  267. return {'pos_embed', 'cls_token', 'dist_token'}
  268. # def get_classifier(self):
  269. # if self.dist_token is None:
  270. # return self.head
  271. # else:
  272. # return self.head, self.head_dist
  273. # def reset_classifier(self, num_classes, global_pool=''):
  274. # self.num_classes = num_classes
  275. # self.head = nn.Linear(self.d_model, num_classes) if num_classes > 0 else nn.Identity()
  276. # if self.num_tokens == 2:
  277. # self.head_dist = nn.Linear(self.d_model, self.num_classes) if num_classes > 0 else nn.Identity()
  278. def forward_features(self, x):
  279. # TODO: move to encoder
  280. x = self.patch_embed(x)
  281. if self.use_cls_token:
  282. cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
  283. if self.dist_token is None:
  284. x = torch.cat((cls_token, x), dim=1)
  285. else:
  286. x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
  287. if self.use_pos_embed:
  288. x = self.pos_drop(x + self.pos_embed)
  289. if self.track_norms: output_norms = [torch.mean(x.detach() ** 2)]
  290. for block in self.blocks:
  291. x, _ = block(x)
  292. if self.track_norms: output_norms.append(torch.mean(x.detach() ** 2))
  293. x = self.norm(x)
  294. if self.track_norms:
  295. metrics = to_dict(output_norms, recursive=False)
  296. self.metrics = {f'norm/{i}': v for i, v in metrics.items()}
  297. if self.dist_token is None:
  298. if self.use_cls_token:
  299. return self.pre_logits(x[:, 0])
  300. else:
  301. # pooling: TODO move to decoder
  302. return self.pre_logits(x.mean(1))
  303. else:
  304. return x[:, 0], x[:, 1]
  305. def forward(self, x, rate=1.0, resolution=None, state=None):
  306. x = self.forward_features(x)
  307. if self.head_dist is not None:
  308. x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
  309. if self.training and not torch.jit.is_scripting():
  310. # during inference, return the average of both classifier predictions
  311. return x, x_dist
  312. else:
  313. return (x + x_dist) / 2
  314. else:
  315. x = self.head(x)
  316. return x, None
  317. def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
  318. """ ViT weight initialization
  319. * When called without n, head_bias, jax_impl args it will behave exactly the same
  320. as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
  321. * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
  322. """
  323. if isinstance(m, (nn.Linear)):
  324. if n.startswith('head'):
  325. nn.init.zeros_(m.weight)
  326. nn.init.constant_(m.bias, head_bias)
  327. elif n.startswith('pre_logits'):
  328. lecun_normal_(m.weight)
  329. nn.init.zeros_(m.bias)
  330. else:
  331. if jax_impl:
  332. nn.init.xavier_uniform_(m.weight)
  333. if m.bias is not None:
  334. if 'mlp' in n:
  335. nn.init.normal_(m.bias, std=1e-6)
  336. else:
  337. nn.init.zeros_(m.bias)
  338. else:
  339. if m.bias is not None:
  340. nn.init.zeros_(m.bias)
  341. dense_init_fn_ = partial(trunc_normal_, std=.02)
  342. if isinstance(m, nn.Linear):
  343. dense_init_fn_(m.weight)
  344. # elif isinstance(m, (BlockSparseLinear, BlockdiagLinear, LowRank)):
  345. # m.set_weights_from_dense_init(dense_init_fn_)
  346. elif jax_impl and isinstance(m, nn.Conv2d):
  347. # NOTE conv was left to pytorch default in my original init
  348. lecun_normal_(m.weight)
  349. if m.bias is not None:
  350. nn.init.zeros_(m.bias)
  351. elif isinstance(m, nn.LayerNorm):
  352. nn.init.zeros_(m.bias)
  353. nn.init.ones_(m.weight)
  354. def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
  355. # Rescale the grid of position embeddings when loading from state_dict. Adapted from
  356. # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
  357. _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
  358. ntok_new = posemb_new.shape[1]
  359. if num_tokens:
  360. posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
  361. ntok_new -= num_tokens
  362. else:
  363. posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
  364. gs_old = int(math.sqrt(len(posemb_grid)))
  365. if not len(gs_new): # backwards compatibility
  366. gs_new = [int(math.sqrt(ntok_new))] * 2
  367. assert len(gs_new) >= 2
  368. _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
  369. posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
  370. posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
  371. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
  372. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  373. return posemb
  374. def checkpoint_filter_fn(state_dict, model):
  375. """ convert patch embedding weight from manual patchify + linear proj to conv"""
  376. out_dict = {}
  377. if 'model' in state_dict:
  378. # For deit models
  379. state_dict = state_dict['model']
  380. for k, v in state_dict.items():
  381. if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
  382. # For old models that I trained prior to conv based patchification
  383. O, I, H, W = model.patch_embed.proj.weight.shape
  384. v = v.reshape(O, -1, H, W)
  385. elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
  386. # To resize pos embedding when using model at different size from pretrained weights
  387. v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1),
  388. model.patch_embed.grid_size)
  389. out_dict[k] = v
  390. return out_dict
  391. def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
  392. if default_cfg is None:
  393. default_cfg = deepcopy(default_cfgs[variant])
  394. overlay_external_default_cfg(default_cfg, kwargs)
  395. default_num_classes = default_cfg['num_classes']
  396. default_img_size = default_cfg['input_size'][-2:]
  397. num_classes = kwargs.pop('num_classes', default_num_classes)
  398. img_size = kwargs.pop('img_size', default_img_size)
  399. repr_size = kwargs.pop('representation_size', None)
  400. if repr_size is not None and num_classes != default_num_classes:
  401. # Remove representation layer if fine-tuning. This may not always be the desired action,
  402. # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
  403. _logger.warning("Removing representation layer for fine-tuning.")
  404. repr_size = None
  405. if kwargs.get('features_only', None):
  406. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  407. model = build_model_with_cfg(
  408. VisionTransformer,
  409. variant,
  410. pretrained,
  411. default_cfg=default_cfg,
  412. img_size=img_size,
  413. num_classes=num_classes,
  414. representation_size=repr_size,
  415. pretrained_filter_fn=checkpoint_filter_fn,
  416. **kwargs)
  417. return model
  418. def vit_small_patch16_224(pretrained=False, **kwargs):
  419. """ Tri's custom 'small' ViT model. d_model=768, depth=8, num_heads=8, mlp_ratio=3.
  420. NOTE:
  421. * this differs from the DeiT based 'small' definitions with d_model=384, depth=12, num_heads=6
  422. * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
  423. """
  424. print(kwargs)
  425. model_kwargs = dict(
  426. patch_size=16,
  427. d_model=768,
  428. depth=8,
  429. # num_heads=8,
  430. expand=3,
  431. # qkv_bias=False,
  432. norm='layer',
  433. # norm_layer=nn.LayerNorm,
  434. )
  435. model_kwargs = {
  436. **model_kwargs,
  437. **kwargs,
  438. }
  439. if pretrained:
  440. # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
  441. model_kwargs.setdefault('qk_scale', 768 ** -0.5)
  442. model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
  443. return model
  444. def vit_base_patch16_224(pretrained=False, **kwargs):
  445. """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  446. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  447. """
  448. model_kwargs = dict(
  449. patch_size=16,
  450. d_model=768,
  451. depth=12,
  452. # num_heads=12,
  453. )
  454. model_kwargs = {
  455. **model_kwargs,
  456. **kwargs,
  457. }
  458. model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
  459. return model
# ============================
# Reference Hydra configs, kept as bare string literals for documentation only.
# They are not assigned or parsed anywhere in this module; vit_base_s4nd below
# hard-codes equivalent values.
#
# s4/configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml
"""
model:
  _name_: vit_b_16
  dropout: 0.0
  drop_path_rate: 0.1
  d_model: 768
  depth: 12
  expand: 4
  norm: layer
  layer_reps: 1
  use_cls_token: false
  use_pos_embed: false
  layer:
    d_state: 64
    final_act: glu
    bidirectional: true
    channels: 2
    lr: 0.001
    n_ssm: 1
    contract_version: 1 # 0 is for 2d, 1 for 1d or 3d (or other)
"""
# s4/configs/model/layer/s4nd.yaml
"""
_name_: s4nd
d_state: 64
channels: 1
bidirectional: true
activation: gelu
final_act: glu
initializer: null
weight_norm: false
trank: 1
dropout: ${..dropout} # Same as null
tie_dropout: ${oc.select:model.tie_dropout,null}
init: legs
rank: 1
dt_min: 0.001
dt_max: 0.1
lr:
  dt: 0.001
  A: 0.001
  B: 0.001
n_ssm: 1
deterministic: false # Special C init
l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to null and kernel will automatically resize
verbose: true
linear: false
"""
  510. def vit_base_s4nd():
  511. """
  512. model:
  513. _name_: vit_b_16
  514. dropout: 0.0
  515. drop_path_rate: 0.1
  516. d_model: 768
  517. depth: 12
  518. expand: 4
  519. norm: layer
  520. layer_reps: 1
  521. use_cls_token: false
  522. use_pos_embed: false
  523. layer:
  524. d_state: 64
  525. final_act: glu
  526. bidirectional: true
  527. channels: 2
  528. lr: 0.001
  529. n_ssm: 1
  530. contract_version: 1 # 0 is for 2d, 1 for 1d or 3d (or other)
  531. """
  532. """
  533. defaults:
  534. - layer: vit
  535. patch_size: 16
  536. d_model: 768
  537. dropout: 0.0
  538. drop_path_rate: 0.0
  539. depth: 8
  540. expand: 3
  541. norm: layer
  542. use_pos_embed: true
  543. use_cls_token: true
  544. """
  545. model = VisionTransformer(
  546. dropout=0.0,
  547. drop_path_rate=0.1,
  548. d_model=768,
  549. depth=12,
  550. expand=4,
  551. norm="layer",
  552. layer_reps=1,
  553. use_cls_token=False,
  554. use_pos_embed=False,
  555. layer=dict(
  556. _name_="s4nd",
  557. d_state=64,
  558. final_act="glu",
  559. bidirectional=True,
  560. channels=2,
  561. lr=0.001,
  562. n_ssm=1,
  563. contract_version=1,
  564. # contract_version=0,
  565. activation="gelu",
  566. initializer=None,
  567. weight_norm=False,
  568. trank=1,
  569. dropout=0,
  570. tie_dropout=0,
  571. init="legs",
  572. dt_min=0.001,
  573. dt_max=0.1,
  574. deterministic=False,
  575. l_max=None,
  576. verbose=True,
  577. linear=False,
  578. )
  579. )
  580. def forward(self, x, rate=1.0, resolution=None, state=None):
  581. x = self.forward_features(x)
  582. if self.head_dist is not None:
  583. x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
  584. if self.training and not torch.jit.is_scripting():
  585. # during inference, return the average of both classifier predictions
  586. return x, x_dist
  587. else:
  588. return (x + x_dist) / 2
  589. else:
  590. x = self.head(x)
  591. return x
  592. model.forward = partial(forward, model)
  593. return model