| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722 |
- import os
- import sys
- import torch
- import random
- import math
- from functools import partial
- from utils import import_abspy, EffectiveReceiptiveField, visualize
- HOME = os.environ["HOME"].rstrip("/")
class ExtraDev:
    """Namespace of one-off analysis helpers for third-party baselines (S4ND, Vim).

    Every public function is called through the class (``ExtraDev.flops_vim()``)
    and takes no instance. They expect a CUDA device and local checkouts of the
    baselines (``./convnexts4nd`` next to this file, ``$HOME/packs/Vim`` and
    ``$HOME/packs/ckpts``).
    """

    # Hyper-parameters of the Vim-small configuration used by every Vim helper below.
    # ``img_size`` and ``fused_add_norm`` vary per helper and are passed separately.
    _VIM_SMALL_KWARGS = dict(
        patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True,
        final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False,
        bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True,
    )

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _load_vim_module():
        """Import and return Vim's ``models_mamba`` module.

        ``$HOME/packs/Vim/mamba-1p1p1`` is put first on ``sys.path`` while
        importing (presumably so ``mamba_ssm`` resolves to the build bundled
        with Vim -- TODO confirm) and is removed again afterwards, even when an
        import fails.
        """
        specpath = f"{HOME}/packs/Vim/mamba-1p1p1"
        sys.path.insert(0, specpath)
        try:
            import mamba_ssm  # noqa: F401
            return import_abspy("models_mamba", f"{HOME}/packs/Vim/vim")
        finally:
            # remove exactly the entry added above instead of blindly dropping sys.path[0]
            if specpath in sys.path:
                sys.path.remove(specpath)

    @staticmethod
    def _build_vim_small(model_lib, img_size, fused_add_norm=True):
        """Instantiate the Vim-small ``VisionMamba`` from ``model_lib`` (on CPU, train mode)."""
        return model_lib.VisionMamba(
            img_size=img_size, fused_add_norm=fused_add_norm, **ExtraDev._VIM_SMALL_KWARGS
        )

    @staticmethod
    def _resize_mid_cls_pos_embed(pos_embed, target_hw):
        """Bicubically resize a position embedding whose cls slot sits in the middle.

        Args:
            pos_embed: tensor of shape ``(1, S*S + 1, C)``; the cls embedding is at
                index ``(S*S) // 2``, matching where ``forward_features`` inserts the
                cls token (``M // 2``).
            target_hw: ``(H, W)`` token grid to resize the patch embeddings to.

        Returns:
            Tensor of shape ``(1, H*W + 1, C)`` with the unchanged cls embedding
            re-inserted at index ``(H*W) // 2``.

        Raises:
            ValueError: if the number of patch embeddings is not a perfect square.
        """
        n_src = pos_embed.shape[1] - 1
        src_side = math.isqrt(n_src)
        if src_side * src_side != n_src:
            raise ValueError(f"expected a square token grid plus one cls token, got {n_src + 1} tokens")
        # The cls slot is derived from the token count; the previous hard-coded
        # index 83 did not match the M // 2 (= 98 for 196 patches) insertion point.
        src_cls = n_src // 2
        cls_embed = pos_embed[:, src_cls:src_cls + 1, :]
        grid = torch.cat([pos_embed[:, :src_cls, :], pos_embed[:, src_cls + 1:, :]], dim=1)
        grid = grid.reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
        grid = torch.nn.functional.interpolate(grid, size=tuple(target_hw), align_corners=False, mode="bicubic")
        n_dst = target_hw[0] * target_hw[1]
        grid = grid.permute(0, 2, 3, 1).contiguous().view(1, n_dst, -1)
        dst_cls = n_dst // 2
        return torch.cat([grid[:, :dst_cls, :], cls_embed, grid[:, dst_cls:, :]], dim=1)

    @staticmethod
    def _load_vim_small_ckpt(model, img_size):
        """Load the Vim-small (middle cls token) checkpoint into ``model``.

        The checkpoint's 197-token ``pos_embed`` is resized when ``img_size``
        yields a token grid other than 14x14 (patch size 16).
        """
        # pass the path so torch.load opens *and closes* the file itself
        ckpt = torch.load(f"{HOME}/packs/ckpts/vim_s_midclstok_80p5acc.pth", map_location=torch.device("cpu"))["model"]

        # to interpolate pos_embed, the cls_token position must be fixed!
        # this checkpoint uses a middle cls token
        assert not model.use_double_cls_token
        assert model.use_middle_cls_token
        assert ckpt["pos_embed"].shape[1] == 197
        side = img_size // 16
        if side * side != 197 - 1:
            ckpt["pos_embed"] = ExtraDev._resize_mid_cls_pos_embed(ckpt["pos_embed"], (side, side))
        model.load_state_dict(ckpt)

    @staticmethod
    def _patch_backbone_forward(model, model_lib):
        """Replace ``model.forward`` so it returns a ``(B, C, H, H)`` feature map.

        The patched forward runs a copy of ``VisionMamba.forward_features`` that
        also reports where the cls token(s) ended up, drops those tokens and
        reshapes the remaining tokens into a square grid.
        """
        RMSNorm, layer_norm_fn, rms_norm_fn = model_lib.RMSNorm, model_lib.layer_norm_fn, model_lib.rms_norm_fn

        # copy from https://github.com/hustvl/Vim/blob/main/vim/models_mamba.py#VisionMamba
        # changed to "return hidden_states, token_position" (the pooling tail is dropped)
        def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
            x = self.patch_embed(x)
            B, M, _ = x.shape
            if self.if_cls_token:
                if self.use_double_cls_token:
                    cls_token_head = self.cls_token_head.expand(B, -1, -1)
                    cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                    token_position = [0, M + 1]
                    x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                else:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    if self.use_middle_cls_token:
                        # add cls token in the middle
                        token_position = M // 2
                        x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    elif if_random_cls_token_position:
                        token_position = random.randint(0, M)
                        x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                        print("token_position: ", token_position)
                    else:
                        token_position = 0
                        x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]

            if self.if_abs_pos_embed:
                # pos_embed is added as-is: the model must be built with the target img_size
                x = x + self.pos_embed
                x = self.pos_drop(x)

            if if_random_token_rank:
                # draw a random permutation of the token order
                shuffle_indices = torch.randperm(M)
                if isinstance(token_position, list):
                    print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
                else:
                    print("original value: ", x[0, token_position, 0])
                print("original token_position: ", token_position)
                # apply the shuffle
                x = x[:, shuffle_indices, :]
                # locate the cls token(s) after the shuffle
                if isinstance(token_position, list):
                    token_position = [torch.where(shuffle_indices == pos)[0].item() for pos in token_position]
                else:
                    token_position = torch.where(shuffle_indices == token_position)[0].item()
                if isinstance(token_position, list):
                    print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
                else:
                    print("new value: ", x[0, token_position, 0])
                print("new token_position: ", token_position)

            if_flip_img_sequences = False
            if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
                x = x.flip([1])
                if_flip_img_sequences = True

            # mamba impl
            residual = None
            hidden_states = x
            if not self.if_bidirectional:
                for layer in self.layers:
                    if if_flip_img_sequences and self.if_rope:
                        hidden_states = hidden_states.flip([1])
                        if residual is not None:
                            residual = residual.flip([1])
                    # rope about
                    if self.if_rope:
                        hidden_states = self.rope(hidden_states)
                        if residual is not None and self.if_rope_residual:
                            residual = self.rope(residual)
                    if if_flip_img_sequences and self.if_rope:
                        hidden_states = hidden_states.flip([1])
                        if residual is not None:
                            residual = residual.flip([1])
                    hidden_states, residual = layer(
                        hidden_states, residual, inference_params=inference_params
                    )
            else:
                # get two layers in a single for-loop
                for i in range(len(self.layers) // 2):
                    if self.if_rope:
                        hidden_states = self.rope(hidden_states)
                        if residual is not None and self.if_rope_residual:
                            residual = self.rope(residual)
                    hidden_states_f, residual_f = self.layers[i * 2](
                        hidden_states, residual, inference_params=inference_params
                    )
                    # identity test: ``residual == None`` on a tensor is an element-wise comparison
                    hidden_states_b, residual_b = self.layers[i * 2 + 1](
                        hidden_states.flip([1]), None if residual is None else residual.flip([1]), inference_params=inference_params
                    )
                    hidden_states = hidden_states_f + hidden_states_b.flip([1])
                    residual = residual_f + residual_b.flip([1])

            if not self.fused_add_norm:
                if residual is None:
                    residual = hidden_states
                else:
                    residual = residual + self.drop_path(hidden_states)
                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
                hidden_states = fused_add_norm_fn(
                    self.drop_path(hidden_states),
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
            return hidden_states, token_position

        # modified from https://github.com/hustvl/Vim/blob/main/vim/models_mamba.py#VisionMamba
        def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
            hs, token_position = forward_features(self, x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
            print("self.if_cls_token", self.if_cls_token, end=" ")
            print("self.use_double_cls_token", self.use_double_cls_token, end=" ")
            print("self.use_middle_cls_token", self.use_middle_cls_token, end=" ")
            print("if_random_cls_token_position", if_random_cls_token_position, end=" ")
            print("if_random_token_rank", if_random_token_rank, end=" ")
            # drop the cls token(s) so only patch tokens remain
            cls_positions = token_position if isinstance(token_position, list) else [token_position]
            indexes = list(range(hs.shape[1]))
            for t in cls_positions:
                indexes.remove(t)
            hs = hs[:, indexes, :].contiguous()
            # the remaining tokens are laid out as a square H x H grid
            H = math.isqrt(hs.shape[1])
            hs = hs.permute(0, 2, 1).contiguous().view(hs.shape[0], -1, H, H)
            return hs

        model.forward = partial(forward, model)

    @staticmethod
    def _mamba_inner_flops(batch, seqlen, dim, d_state, dt_rank, conv_width, with_d=False, with_z=True, out_dim=None):
        """Return the FLOPs estimate for one fused Mamba inner call.

        Terms (one count per multiply-accumulate, as in the per-term comments):
        causal conv1d, ``x_proj`` linear, ``delta`` projection, selective scan
        (``9 * B * L * D * N``, see https://github.com/state-spaces/mamba/issues/110),
        the optional ``D`` skip and ``z`` gate, and the optional output projection
        to ``out_dim`` features.
        """
        flops = 0
        flops += batch * (dim * seqlen) * conv_width  # causal_conv1d_cuda.causal_conv1d_fwd
        flops += batch * (dim * seqlen) * (dt_rank + d_state + d_state)  # x_dbl = F.linear(...
        flops += batch * (dim * dt_rank) * seqlen  # delta_proj_weight @ x_dbl[:, :delta_rank]
        # accumulate (the earlier code re-assigned here and lost the three terms above)
        flops += 9 * batch * seqlen * dim * d_state
        if with_d:
            flops += batch * dim * seqlen
        if with_z:
            flops += batch * dim * seqlen
        if out_dim is not None:
            flops += batch * dim * seqlen * out_dim  # F.linear(out_z, out_proj_weight, ...)
        return flops

    # ------------------------------------------------------------------
    # public entry points
    # ------------------------------------------------------------------
    # recorded outputs (presumably "GFlops Params" for ctiny and vitb respectively):
    # 5.162112298177406 30007057
    # 17.069485400571516 91157224
    @staticmethod
    def flops_s4nd(size=224, scale="ctiny"):
        """Print GFlops and parameter count of an S4ND model for a ``size`` x ``size`` input.

        Args:
            size: side length of the square RGB input.
            scale: ``"ctiny"`` (``convnext_tiny_s4nd``) or ``"vitb"`` (``vit_base_s4nd``).

        Raises:
            KeyError: if ``scale`` is not one of the two supported names.
        """
        from fvcore.nn import flop_count

        # absolute path: os.path.dirname(__file__) may be "" when run from the script's folder
        specpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "convnexts4nd")
        sys.path.insert(0, specpath)
        try:
            import timm
            assert timm.__version__ == "0.5.4"
            import structured_kernels  # noqa: F401  presumably the compiled S4ND kernels -- TODO confirm
            vit_lib = import_abspy("vit_all", specpath)
            convnext_lib = import_abspy("convnext_timm", specpath)
        finally:
            if specpath in sys.path:
                sys.path.remove(specpath)

        # only the requested model is instantiated (the other one used to be built on GPU for nothing)
        factory = {"ctiny": convnext_lib.convnext_tiny_s4nd, "vitb": vit_lib.vit_base_s4nd}[scale]
        model = factory().cuda().eval()
        # cauchy_mult makes only little difference as there's only 2-3 div operations

        def _supported_ops():
            def aten_fft(inputs, outputs, fftop="fft"):
                """fvcore handle: FLOPs estimate for ``aten::fft_*`` ops from the traced graph values."""
                inp, num, dim, _norm = inputs[0:4]
                is_complex = torch.is_complex(torch.tensor([], dtype=inputs[0].type().dtype()))
                shape = inp.type().sizes()
                dim_type = dim.type()
                if isinstance(dim_type, torch._C.IntType):
                    dims = [dim.toIValue()]
                elif isinstance(dim_type, torch._C.TupleType):
                    dims = [d.toIValue() for d in dim_type.elements()]
                elif isinstance(dim_type, torch._C.NoneType):
                    assert fftop == "fftn"
                    if isinstance(num.type(), torch._C.NoneType):
                        dims = list(range(len(shape)))
                    elif isinstance(num.type(), torch._C.ListType):
                        assert isinstance(num.type().getElementType(), torch._C.IntType)
                        num_dims = len(tuple(num.node().inputs()))
                        guess_dim = [-1 - i for i in range(num_dims)]
                        print(f"Warning, We are not sure about this, guess the fft dim are {guess_dim}. input: {shape}, n or s: {num}, dim: {dim.type()}")
                        dims = guess_dim
                    else:
                        raise NotImplementedError
                else:
                    raise NotImplementedError

                # NOTE(review): for multi-dim FFTs this multiplies the per-dim log2 terms;
                # the usual estimate sums them -- kept as-is, confirm which is intended.
                flops = math.prod(shape) * math.prod([math.log2(shape[i]) for i in dims]) * (4 if is_complex else 1)
                return flops

            supported_ops = {
                "aten::fft_fft": aten_fft,
                "aten::fft_rfft": aten_fft,
                "aten::fft_rfftn": partial(aten_fft, fftop="fftn"),
                "aten::fft_irfft": aten_fft,
                "aten::fft_irfftn": partial(aten_fft, fftop="fftn"),
            }
            return supported_ops

        input_shape = (1, 3, size, size)
        inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
        model(inputs[0])  # to force init first
        Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=_supported_ops())
        print("GFlops: ", sum(Gflops.values()), "Params: ", sum([p.numel() for _, p in model.named_parameters()]), flush=True)

    @staticmethod
    def build_vim_for_throughput(with_ckpt=False, remove_head=False, only_backbone=False, size=224):
        """Build a CUDA, eval-mode Vim-small model for throughput measurements.

        Args:
            with_ckpt: load the pretrained checkpoint (pos_embed is resized for ``size``).
            remove_head: call the stock forward with ``return_features=True``
                (ignored when ``only_backbone`` is set).
            only_backbone: patch ``forward`` to return the ``(B, C, H, H)`` patch-token map.
            size: input image side length.

        Returns:
            The configured model.
        """
        model_lib = ExtraDev._load_vim_module()
        model = ExtraDev._build_vim_small(model_lib, img_size=size, fused_add_norm=True)

        if only_backbone:
            ExtraDev._patch_backbone_forward(model, model_lib)
        elif remove_head:
            model.forward = partial(model.forward, return_features=True)
        model = model.cuda().eval()

        if with_ckpt:
            ExtraDev._load_vim_small_ckpt(model, img_size=size)
        return model

    # previously recorded output (presumably "GFlops Params"): 5.301500928 25796584
    # NOTE: that figure predates the fix that adds the conv/projection terms back into
    # the Mamba handle, so the printed GFlops is likely to differ now.
    @staticmethod
    def flops_vim(size=224):
        """Print GFlops and parameter count of Vim-small for a ``size`` x ``size`` input."""
        from fvcore.nn import flop_count

        model_lib = ExtraDev._load_vim_module()
        # fused add norm share the same flops as naive one
        model = ExtraDev._build_vim_small(model_lib, img_size=size, fused_add_norm=False)
        model = model.cuda().eval()

        # RMSNorm share the same flops as naive one
        # https://github.com/state-spaces/mamba/blob/v1.2.2/mamba_ssm/ops/triton/layernorm.py
        def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
            """Plain-PyTorch RMSNorm, installed below in place of each RMSNorm's forward for counting."""
            dtype = x.dtype
            if upcast:
                weight = weight.float()
                bias = bias.float() if bias is not None else None
                x = x.float()
                residual = residual.float() if residual is not None else residual
            if residual is not None:
                x = (x + residual).to(x.dtype)
            rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
            out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
            out = out.to(dtype)
            return out if not prenorm else (out, x)

        def rms_forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
            return rms_norm_ref(
                x,
                self.weight,
                self.bias,
                residual=residual,
                eps=self.eps,
                prenorm=prenorm,
                upcast=residual_in_fp32,
            )

        for _, m in model.named_modules():
            if isinstance(m, model_lib.RMSNorm):
                m.forward = partial(rms_forward, m)

        def causal_conv_1d_jit(inputs, outputs):
            """
            https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
            x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) out: (batch, dim, seqlen)
            out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
            """
            from fvcore.nn.jit_handles import conv_flop_jit
            return conv_flop_jit(inputs, outputs)

        def _sizes(value):
            return tuple(value.type().sizes())

        def _mamba_trace_flops(inputs, a_index, out_proj_index=None):
            """
            Shared body of the two fused-Mamba handles. Reference computation:

            conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
            x, z = xz.chunk(2, dim=1)
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias,None, True)
            x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
            B = x_dbl[:, delta_rank:delta_rank + d_state]
            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            C = x_dbl[:, -d_state:]
            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            out, scan_intermediates, out_z = selective_scan_cuda.fwd(
                conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
            )
            [MambaInnerFn only] F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

            ``a_index`` is the position of ``A`` among the traced inputs; ``D`` is
            looked up right after it (assumes B and C are not passed as inputs --
            TODO confirm against the traced graph).
            """
            xz, conv1d_weight, x_proj_weight, delta_proj_weight = inputs[0], inputs[1], inputs[3], inputs[4]
            A = inputs[a_index]
            Batch, _, L = _sizes(xz)
            CWidth = _sizes(conv1d_weight)[-1]
            H = _sizes(A)[-1]  # d_state (original note: 16)
            Dim, R = _sizes(delta_proj_weight)
            assert _sizes(xz) == (Batch, 2 * Dim, L)
            assert _sizes(conv1d_weight) == (Dim, 1, CWidth)
            assert _sizes(x_proj_weight) == (R + H + H, Dim)
            assert _sizes(A) == (Dim, H)

            d_index = a_index + 1
            with_D = len(inputs) > d_index and "D" in inputs[d_index].debugName()
            if with_D:
                assert _sizes(inputs[d_index]) == (Dim,)

            out_dim = None
            if out_proj_index is not None:
                # read the sizes of the graph value; the value itself is not indexable
                out_dim, out_in = _sizes(inputs[out_proj_index])
                assert out_in == Dim

            return ExtraDev._mamba_inner_flops(
                Batch, L, Dim, H, R, CWidth, with_d=with_D, with_z=True, out_dim=out_dim
            )

        # ONLY FOR VisionMamba
        def MambaInnerFnNoOutProj_jit(inputs, outputs):
            # inputs: xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, ...
            return _mamba_trace_flops(inputs, a_index=5)

        # ONLY FOR Mamba
        def MambaInnerFn_jit(inputs, outputs):
            # inputs: xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
            #         out_proj_weight, out_proj_bias, A, ...
            return _mamba_trace_flops(inputs, a_index=7, out_proj_index=5)

        supported_ops = {
            "aten::gelu": None,  # as relu is in _IGNORED_OPS
            "aten::silu": None,  # as relu is in _IGNORED_OPS
            "aten::neg": None,  # as relu is in _IGNORED_OPS
            "aten::exp": None,  # as relu is in _IGNORED_OPS
            "aten::flip": None,  # as permute is in _IGNORED_OPS
            "prim::PythonOp.CausalConv1dFn": causal_conv_1d_jit,
            "prim::PythonOp.MambaInnerFnNoOutProj": MambaInnerFnNoOutProj_jit,
            "prim::PythonOp.MambaInnerFn": MambaInnerFn_jit,
        }
        input_shape = (1, 3, size, size)
        inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
        Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=supported_ops)
        print("GFlops: ", sum(Gflops.values()), "Params: ", sum([p.numel() for _, p in model.named_parameters()]), flush=True)

    @staticmethod
    def erf_vim(data_path = "/media/Disk1/Dataset/ImageNet_ILSVRC2012"):
        """Compute Vim-small input-gradient (ERF) maps at 1024x1024, before and after loading weights.

        Args:
            data_path: dataset root forwarded to ``EffectiveReceiptiveField.get_input_grad_avg``.

        Returns:
            ``(model_before, model_after)``: the ``get_input_grad_avg`` results for the
            randomly initialised model and for the same model with the checkpoint loaded.
        """
        print("vim ================================", flush=True)
        size = 1024
        model_lib = ExtraDev._load_vim_module()
        model = ExtraDev._build_vim_small(model_lib, img_size=size, fused_add_norm=True)
        ExtraDev._patch_backbone_forward(model, model_lib)

        vims = model.cuda().eval()
        model_before = EffectiveReceiptiveField.get_input_grad_avg(vims, size=size, data_path=data_path, norms=EffectiveReceiptiveField.simpnorm)

        # with ckpt (pos_embed is resized from the 14x14 grid to 64x64)
        ExtraDev._load_vim_small_ckpt(model, img_size=size)
        vims = model.cuda().eval()
        model_after = EffectiveReceiptiveField.get_input_grad_avg(vims, size=size, data_path=data_path, norms=EffectiveReceiptiveField.simpnorm)
        return model_before, model_after
if __name__ == "__main__":
    # Script entry point: print FLOPs/params for Vim and S4ND, then save the
    # before/after-checkpoint ERF maps of Vim as one figure.
    showpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./show").rstrip("/")
    data_path = "/media/Disk1/Dataset/ImageNet_ILSVRC2012"
    # create the output folder up-front so the long ERF run cannot fail at save time
    os.makedirs(showpath, exist_ok=True)

    ExtraDev.flops_vim()
    ExtraDev.flops_s4nd()
    # pass the dataset root explicitly instead of leaving ``data_path`` unused
    vim_before, vim_after = ExtraDev.erf_vim(data_path=data_path)
    visualize.visualize_snsmaps([(vim_before, ""), (vim_after, "")], savefig=f"{showpath}/erf_s4ndmethods.jpg", rows=2, sticks=False, figsize=(10, 10.75), cmap='RdYlGn')
|