| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666 |
- """The original Vision Transformer (ViT) from timm.
- Copyright 2020 Ross Wightman.
- """
- import math
- import logging
- from functools import partial
- from collections import OrderedDict
- from copy import deepcopy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
- from timm.models.layers import PatchEmbed, Mlp, trunc_normal_, lecun_normal_
- from src.models.sequence.base import SequenceModule
- from src.models.nn import Normalization
- from src.models.sequence.backbones.block import SequenceResidualBlock
- from src.utils.config import to_list, to_dict
- _logger = logging.getLogger(__name__)
- def _cfg(url='', **kwargs):
- return {
- 'url': url,
- 'num_classes': 1000,
- 'input_size': (3, 224, 224),
- 'pool_size': None,
- # 'crop_pct': .9,
- # 'interpolation': 'bicubic',
- # 'fixed_input_size': True,
- # 'mean': IMAGENET_DEFAULT_MEAN,
- # 'std': IMAGENET_DEFAULT_STD,
- # 'first_conv': 'patch_embed.proj',
- 'classifier': 'head',
- **kwargs,
- }
# Pretrained-weight configurations, keyed by model variant name.
default_cfgs = dict(
    # patch models (my experiments)
    vit_small_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),
    # patch models (weights ported from official Google JAX impl)
    vit_base_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
    ),
)
- # class Block(nn.Module):
- # def __init__(
- # self,
- # dim,
- # num_heads,
- # mlp_ratio=4.,
- # qkv_bias=False,
- # qk_scale=None,
- # drop=0.,
- # attn_drop=0.,
- # drop_path=0.,
- # act_layer=nn.GELU,
- # norm_layer=nn.LayerNorm,
- # attnlinear_cfg=None,
- # mlp_cfg=None
- # ):
- # super().__init__()
- # self.norm1 = norm_layer(dim)
- # self.attn = AttentionSimple(
- # dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
- # linear_cfg=attnlinear_cfg)
- # self.drop_path = StochasticDepth(drop_path, mode='row')
- # self.norm2 = norm_layer(dim)
- # mlp_hidden_dim = int(dim * mlp_ratio)
- # if mlp_cfg is None:
- # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
- # else:
- # self.mlp = hydra.utils.instantiate(mlp_cfg, in_features=dim, hidden_features=mlp_hidden_dim,
- # act_layer=act_layer, drop=drop, _recursive_=False)
- # def forward(self, x):
- # x = x + self.drop_path(self.attn(self.norm1(x)))
- # x = x + self.drop_path(self.mlp(self.norm2(x)))
- # return x
class VisionTransformer(SequenceModule):
    """Vision Transformer backbone with configurable sequence-mixing layers.

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877

    Each depth step stacks the configured ``layer`` entries followed by an
    optional feed-forward block; every one of them is wrapped in a pre-norm
    ``SequenceResidualBlock``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        d_model=768,
        depth=12,
        expand=4,
        representation_size=None,
        distilled=False,
        dropout=0.,
        drop_path_rate=0.,
        embed_layer=PatchEmbed,
        norm='layer',
        weight_init='',
        layer=None,
        transposed=False,
        layer_reps=1,
        use_pos_embed=False,
        use_cls_token=False,
        track_norms=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size, forwarded to ``embed_layer``
            patch_size (int, tuple): patch size, forwarded to ``embed_layer``
            in_chans (int): number of input channels, forwarded to ``embed_layer``
            num_classes (int): number of classes for the classification head;
                a non-positive value replaces the head with ``nn.Identity``
            d_model (int): embedding dimension
            depth (int): number of depth steps (each holds the mixing layers plus one FFN)
            expand (int): expansion factor of the feed-forward block; the FFN
                block is only added when this is > 0
            representation_size (Optional[int]): enable and set representation layer
                (pre-logits) to this value if set; ignored when ``distilled``
            distilled (bool): model includes a distillation token and head as in
                DeiT models; requires ``use_cls_token``
            dropout (float): dropout rate used for the position dropout, the
                residual blocks, the FFN, and any layer config without its own value
            drop_path_rate (float): final stochastic depth rate; rates are spaced
                linearly from 0 across the depth steps
            embed_layer (nn.Module): patch embedding layer class; the instance
                must expose ``num_patches``
            norm (str, dict, None): normalization spec passed to every residual
                block and used for the final norm (name, kwargs mapping, or None
                to skip the final norm)
            weight_init (str): weight init scheme, one of '', 'jax', 'jax_nlhb', 'nlhb'
            layer: config mapping (or list of mappings) describing the mixing
                layer(s) handed to ``SequenceResidualBlock``
            transposed (bool): written into every layer config and passed to the
                final normalization
            layer_reps (int): how many times the layer list is repeated inside
                one depth step
            use_pos_embed (bool): add a learned position embedding to the tokens
            use_cls_token (bool): prepend a class token and classify from it;
                otherwise features are mean-pooled over the token dimension
            track_norms (bool): record the mean squared activation after the
                embedding and after every block in ``self.metrics``
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.d_model = d_model  # num_features for consistency with other models
        # Count of prefix tokens actually prepended to the patch sequence. Without a
        # class token nothing is prepended, so the position embedding must not
        # reserve a slot for one (otherwise `x + pos_embed` has mismatched lengths).
        if use_cls_token:
            self.num_tokens = 2 if distilled else 1
        else:
            self.num_tokens = 0
        self.use_pos_embed = use_pos_embed
        self.use_cls_token = use_cls_token
        self.track_norms = track_norms

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=d_model,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = None
        self.dist_token = None
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) if distilled else None
        else:
            assert not distilled, 'Distillation token not supported without class token'

        self.pos_embed = None
        if use_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, d_model))
        self.pos_drop = nn.Dropout(p=dropout)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.transposed = transposed

        # Work on copies so the caller's config objects are never modified.
        layer = [deepcopy(_layer) for _layer in to_list(layer, recursive=False)] * layer_reps

        # Some special arguments are passed into each layer
        for _layer in layer:
            # If layers don't specify dropout, add it
            if _layer.get('dropout', None) is None:
                _layer['dropout'] = dropout
            # Ensure all layers are shaped the same way
            _layer['transposed'] = transposed

        # Config for the inverted bottleneck
        ff_cfg = {
            '_name_': 'ffn',
            'expand': int(expand),
            'transposed': self.transposed,
            'activation': 'gelu',
            'initializer': None,
            'dropout': dropout,
        }

        blocks = []
        for i in range(depth):
            for _layer in layer:
                blocks.append(
                    SequenceResidualBlock(
                        d_input=d_model,
                        i_layer=i,
                        prenorm=True,
                        dropout=dropout,
                        layer=_layer,
                        residual='R',
                        norm=norm,
                        pool=None,
                        drop_path=dpr[i],
                    )
                )
            # One FFN block per depth step, after that step's mixing layers
            # (identical to a per-layer FFN when a single layer is configured).
            # NOTE(review): confirm this placement for multi-layer configs.
            if expand > 0:
                blocks.append(
                    SequenceResidualBlock(
                        d_input=d_model,
                        i_layer=i,
                        prenorm=True,
                        dropout=dropout,
                        layer=ff_cfg,
                        residual='R',
                        norm=norm,
                        pool=None,
                        drop_path=dpr[i],
                    )
                )
        self.blocks = nn.Sequential(*blocks)

        # Final normalization (skipped entirely when norm is None)
        if norm is None:
            self.norm = None
        elif isinstance(norm, str):
            self.norm = Normalization(d_model, transposed=self.transposed, _name_=norm)
        else:
            self.norm = Normalization(d_model, transposed=self.transposed, **norm)

        # Representation layer: generally defaults to nn.Identity()
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(d_model, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s): TODO: move to decoder
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.d_model, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if weight_init.startswith('jax'):
            # leave cls token as zeros to match jax impl
            for n, m in self.named_modules():
                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
        else:
            if self.cls_token is not None:
                trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token', 'dist_token'}

    def forward_features(self, x):
        """Embed the input, run all blocks, and return the pooled features.

        Returns:
            Without a distillation token: the pre-logits features taken from the
            class token (``use_cls_token``) or from the mean over dim 1.
            With a distillation token: the tuple ``(x[:, 0], x[:, 1])`` of class
            and distillation token features.
        """
        # TODO: move to encoder
        x = self.patch_embed(x)
        if self.use_cls_token:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            if self.dist_token is None:
                x = torch.cat((cls_token, x), dim=1)
            else:
                x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        if self.use_pos_embed:
            x = self.pos_drop(x + self.pos_embed)

        if self.track_norms:
            output_norms = [torch.mean(x.detach() ** 2)]

        for block in self.blocks:
            x, _ = block(x)  # each block returns (output, state); the state is discarded
            if self.track_norms:
                output_norms.append(torch.mean(x.detach() ** 2))

        # norm=None disables the final normalization
        if self.norm is not None:
            x = self.norm(x)

        if self.track_norms:
            metrics = to_dict(output_norms, recursive=False)
            self.metrics = {f'norm/{i}': v for i, v in metrics.items()}

        if self.dist_token is None:
            if self.use_cls_token:
                return self.pre_logits(x[:, 0])
            else:
                # pooling: TODO move to decoder
                return self.pre_logits(x.mean(1))
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x, rate=1.0, resolution=None, state=None):
        """Classify a batch of inputs.

        ``rate``, ``resolution`` and ``state`` are accepted but not used here.

        Returns:
            Without a distillation head: ``(logits, None)``.
            With a distillation head: ``(logits, dist_logits)`` while training
            (and not scripting), otherwise the average of both predictions as a
            single tensor.
        """
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x, None
- def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
- """ ViT weight initialization
- * When called without n, head_bias, jax_impl args it will behave exactly the same
- as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
- * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
- """
- if isinstance(m, (nn.Linear)):
- if n.startswith('head'):
- nn.init.zeros_(m.weight)
- nn.init.constant_(m.bias, head_bias)
- elif n.startswith('pre_logits'):
- lecun_normal_(m.weight)
- nn.init.zeros_(m.bias)
- else:
- if jax_impl:
- nn.init.xavier_uniform_(m.weight)
- if m.bias is not None:
- if 'mlp' in n:
- nn.init.normal_(m.bias, std=1e-6)
- else:
- nn.init.zeros_(m.bias)
- else:
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- dense_init_fn_ = partial(trunc_normal_, std=.02)
- if isinstance(m, nn.Linear):
- dense_init_fn_(m.weight)
- # elif isinstance(m, (BlockSparseLinear, BlockdiagLinear, LowRank)):
- # m.set_weights_from_dense_init(dense_init_fn_)
- elif jax_impl and isinstance(m, nn.Conv2d):
- # NOTE conv was left to pytorch default in my original init
- lecun_normal_(m.weight)
- if m.bias is not None:
- nn.init.zeros_(m.bias)
- elif isinstance(m, nn.LayerNorm):
- nn.init.zeros_(m.bias)
- nn.init.ones_(m.weight)
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    """Rescale the grid of position embeddings when loading from a state_dict.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: checkpoint position embedding, indexed as (1, tokens, dim).
        posemb_new: the model's position embedding; only its shape is read.
        num_tokens: number of leading prefix (class/distillation) entries,
            which are carried over without interpolation.
        gs_new: target grid size (height, width); when empty, a square grid is
            derived from the number of non-prefix entries in ``posemb_new``.

    Returns:
        The prefix entries concatenated with the bilinearly interpolated grid.
        The source grid is treated as square (side = int(sqrt(grid length))).
    """
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    target_len = posemb_new.shape[1]
    if num_tokens:
        prefix = posemb[:, :num_tokens]
        grid = posemb[0, num_tokens:]
        target_len -= num_tokens
    else:
        prefix = posemb[:, :0]
        grid = posemb[0]

    old_side = int(math.sqrt(len(grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(target_len))] * 2
    assert len(gs_new) >= 2
    _logger.info('Position embedding grid-size from %s to %s', [old_side, old_side], gs_new)

    # (tokens, dim) -> (1, dim, side, side) so the grid can be resized like an image
    grid = grid.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=gs_new, mode='bilinear')
    # back to (1, new_tokens, dim)
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([prefix, grid], dim=1)
def checkpoint_filter_fn(state_dict, model):
    """Adapt a pretrained state dict to ``model`` before loading.

    * Unwraps checkpoints that nest the weights under a ``'model'`` key (DeiT).
    * Converts a patch-embedding weight stored for manual patchify + linear
      projection (fewer than 4 dims) to the conv layout of
      ``model.patch_embed.proj.weight``.
    * Resizes ``pos_embed`` when its shape differs from the model's. If the
      model has no position embedding (``pos_embed`` is None or missing), the
      entry is passed through unchanged instead of raising.

    Args:
        state_dict: checkpoint mapping of parameter names to tensors.
        model: the model the weights will be loaded into.

    Returns:
        A new dict with the same keys and the converted tensors.
    """
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']

    model_pos_embed = getattr(model, 'pos_embed', None)
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
            # For old models that I trained prior to conv based patchification
            out_ch, _, kernel_h, kernel_w = model.patch_embed.proj.weight.shape
            v = v.reshape(out_ch, -1, kernel_h, kernel_w)
        elif k == 'pos_embed' and model_pos_embed is not None and v.shape != model_pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v,
                model_pos_embed,
                getattr(model, 'num_tokens', 1),
                model.patch_embed.grid_size,
            )
        out_dict[k] = v
    return out_dict
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
    """Build a ``VisionTransformer`` for ``variant`` through ``build_model_with_cfg``.

    Args:
        variant: key into ``default_cfgs``; also passed on as the variant name.
        pretrained: forwarded to ``build_model_with_cfg``.
        default_cfg: config to use instead of a copy of ``default_cfgs[variant]``.
        **kwargs: model arguments. ``num_classes``, ``img_size`` and
            ``representation_size`` are popped and resolved against the default
            config; the rest is forwarded unchanged.

    Raises:
        RuntimeError: if ``features_only`` is requested.
    """
    if default_cfg is None:
        default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)

    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-2:]

    num_classes = kwargs.pop('num_classes', default_num_classes)
    img_size = kwargs.pop('img_size', default_img_size)
    repr_size = kwargs.pop('representation_size', None)

    fine_tuning = num_classes != default_num_classes
    if fine_tuning and repr_size is not None:
        # Remove representation layer if fine-tuning. This may not always be the desired action,
        # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        VisionTransformer,
        variant,
        pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
def vit_small_patch16_224(pretrained=False, **kwargs):
    """Tri's custom 'small' ViT model: patch 16, d_model=768, depth=8, expand=3.

    NOTE:
        * this differs from the DeiT based 'small' definitions with d_model=384, depth=12
        * no ``qk_scale`` override is injected for pretrained weights any more:
          ``VisionTransformer`` has no such parameter, so passing it made
          ``pretrained=True`` fail with a TypeError.

    Args:
        pretrained: forwarded to ``_create_vision_transformer``.
        **kwargs: model arguments; they override the defaults listed above.
    """
    _logger.debug('vit_small_patch16_224 kwargs: %s', kwargs)
    model_kwargs = dict(
        patch_size=16,
        d_model=768,
        depth=8,
        expand=3,
        norm='layer',
    )
    model_kwargs.update(kwargs)
    return _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).

    ImageNet-1k weights fine-tuned from in21k @ 224x224, source
    https://github.com/google-research/vision_transformer.

    Defaults are patch_size=16, d_model=768, depth=12; any entry in ``kwargs``
    overrides them.
    """
    model_kwargs = dict(patch_size=16, d_model=768, depth=12)
    model_kwargs.update(kwargs)
    return _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
- # ============================
- # s4/configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml
- """
- model:
- _name_: vit_b_16
- dropout: 0.0
- drop_path_rate: 0.1
- d_model: 768
- depth: 12
- expand: 4
- norm: layer
- layer_reps: 1
- use_cls_token: false
- use_pos_embed: false
- layer:
- d_state: 64
- final_act: glu
- bidirectional: true
- channels: 2
- lr: 0.001
- n_ssm: 1
- contract_version: 1 # 0 is for 2d, 1 for 1d or 3d (or other)
- """
- # s4/configs/model/layer/s4nd.yaml
- """
- _name_: s4nd
- d_state: 64
- channels: 1
- bidirectional: true
- activation: gelu
- final_act: glu
- initializer: null
- weight_norm: false
- trank: 1
- dropout: ${..dropout} # Same as null
- tie_dropout: ${oc.select:model.tie_dropout,null}
- init: legs
- rank: 1
- dt_min: 0.001
- dt_max: 0.1
- lr:
- dt: 0.001
- A: 0.001
- B: 0.001
- n_ssm: 1
- deterministic: false # Special C init
- l_max: ${oc.select:dataset.__l_max,null} # Grab dataset length if exists, otherwise set to null and kernel will automatically resize
- verbose: true
- linear: false
- """
def vit_base_s4nd():
    """Build a ViT-Base sized model whose mixing layer is an ``s4nd`` layer.

    Model settings: dropout 0.0, drop_path_rate 0.1, d_model 768, depth 12,
    expand 4, layer norm, layer_reps 1, no class token, no position embedding.
    They appear to mirror the ``model`` section of
    s4/configs/experiment/s4nd/vit/vit_b_16_s4_imagenet_v2.yaml, with the
    remaining layer options written out explicitly.

    The returned model has its ``forward`` replaced on the instance so that the
    non-distilled path returns the logits tensor alone rather than
    ``(logits, None)``.
    """
    s4nd_layer = dict(
        _name_="s4nd",
        d_state=64,
        final_act="glu",
        bidirectional=True,
        channels=2,
        lr=0.001,
        n_ssm=1,
        contract_version=1,  # 0 is for 2d, 1 for 1d or 3d (or other)
        activation="gelu",
        initializer=None,
        weight_norm=False,
        trank=1,
        dropout=0,
        tie_dropout=0,
        init="legs",
        dt_min=0.001,
        dt_max=0.1,
        deterministic=False,
        l_max=None,
        verbose=True,
        linear=False,
    )

    model = VisionTransformer(
        dropout=0.0,
        drop_path_rate=0.1,
        d_model=768,
        depth=12,
        expand=4,
        norm="layer",
        layer_reps=1,
        use_cls_token=False,
        use_pos_embed=False,
        layer=s4nd_layer,
    )

    def forward(self, x, rate=1.0, resolution=None, state=None):
        # Same as the class forward, except the plain path returns only the logits.
        feats = self.forward_features(x)
        if self.head_dist is None:
            return self.head(feats)
        # feats must be a (class, distillation) tuple here
        logits, logits_dist = self.head(feats[0]), self.head_dist(feats[1])
        if self.training and not torch.jit.is_scripting():
            return logits, logits_dist
        return (logits + logits_dist) / 2

    # Bind the replacement to this instance only; the class is left untouched.
    model.forward = partial(forward, model)

    return model
|