# analyze_for_vim.py
  1. import os
  2. import sys
  3. import torch
  4. import random
  5. import math
  6. from functools import partial
  7. from utils import import_abspy, EffectiveReceiptiveField, visualize
  8. HOME = os.environ["HOME"].rstrip("/")
  9. class ExtraDev:
  10. # 5.162112298177406 30007057
  11. # 17.069485400571516 91157224
  12. def flops_s4nd(size=224, scale="ctiny"):
  13. import math
  14. from fvcore.nn import flop_count
  15. specpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./convnexts4nd")
  16. sys.path.insert(0, specpath)
  17. import timm; assert timm.__version__ == "0.5.4"
  18. import structured_kernels
  19. model1 = import_abspy("vit_all", f"{os.path.dirname(__file__)}/convnexts4nd")
  20. vitb = model1.vit_base_s4nd
  21. vitb = vitb().cuda().eval()
  22. model2 = import_abspy("convnext_timm", f"{os.path.dirname(__file__)}/convnexts4nd")
  23. ctiny = model2.convnext_tiny_s4nd
  24. ctiny = ctiny().cuda().eval()
  25. sys.path = sys.path[1:]
  26. # cauchy_mult makes only little difference as there's only 2-3 div operations
  27. def _supported_ops():
  28. def aten_fft(inputs, outputs, fftop="fft"):
  29. from torch.fft import fft, rfft, irfft, fftn, rfftn, irfftn
  30. inp, num, dim, norm = inputs[0:4]
  31. is_complex = torch.is_complex(torch.tensor([], dtype=inputs[0].type().dtype()))
  32. shape = inp.type().sizes()
  33. torch._C.Value
  34. if isinstance(dim.type(), torch._C.IntType):
  35. dim = [dim.toIValue()]
  36. elif isinstance(dim.type(), torch._C.TupleType):
  37. dim = [d.toIValue() for d in dim.type().elements()]
  38. elif isinstance(dim.type(), torch._C.NoneType):
  39. from torch.fft import fftn, rfftn, irfftn
  40. assert fftop == "fftn"
  41. if isinstance(num.type(), torch._C.NoneType):
  42. dim = list(range(len(shape)))
  43. elif isinstance(num.type(), torch._C.ListType):
  44. assert isinstance(num.type().getElementType(), torch._C.IntType)
  45. num_dims = len(tuple(num.node().inputs()))
  46. guess_dim = [-1 - i for i in range(num_dims)]
  47. print(f"Warning, We are not sure about this, guess the fft dim are {guess_dim}. input: {shape}, n or s: {num}, dim: {dim.type()}")
  48. dim = guess_dim
  49. else:
  50. raise NotImplementedError
  51. flops = math.prod(shape) * math.prod([math.log2(shape[i]) for i in dim]) * (4 if is_complex else 1)
  52. # print(flops, dim, is_complex, shape)
  53. return flops
  54. supported_ops={
  55. "aten::fft_fft": aten_fft,
  56. "aten::fft_rfft": aten_fft,
  57. "aten::fft_rfftn": partial(aten_fft, fftop="fftn"),
  58. "aten::fft_irfft": aten_fft,
  59. "aten::fft_irfftn": partial(aten_fft, fftop="fftn"),
  60. }
  61. return supported_ops
  62. model = {"ctiny": ctiny, "vitb": vitb}[scale]
  63. input_shape = (1, 3, size, size)
  64. inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
  65. model(inputs[0]) # to force init first
  66. Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=_supported_ops())
  67. print("GFlops: ", sum(Gflops.values()), "Params: ", sum([p.numel() for _, p in model.named_parameters()]), flush=True)
  68. def build_vim_for_throughput(with_ckpt=False, remove_head=False, only_backbone=False, size=224):
  69. img_size = size
  70. imgHW = int(math.sqrt(img_size))
  71. specpath = f"{HOME}/packs/Vim/mamba-1p1p1"
  72. sys.path.insert(0, specpath)
  73. import mamba_ssm
  74. _model = import_abspy("models_mamba", f"{HOME}/packs/Vim/vim")
  75. sys.path = sys.path[1:]
  76. # model = _model.vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2()
  77. kwargs=dict()
  78. # model = _model.VisionMamba(patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  79. model = _model.VisionMamba(img_size=img_size, patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  80. if only_backbone:
  81. # copy from https://github.com/hustvl/Vim/blob/main/vim/models_mamba.py#VisionMamba
  82. # added "return hidden_states, token_position"
  83. RMSNorm, layer_norm_fn, rms_norm_fn = _model.RMSNorm, _model.layer_norm_fn, _model.rms_norm_fn
  84. def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
  85. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  86. # with slight modifications to add the dist_token
  87. x = self.patch_embed(x)
  88. B, M, _ = x.shape
  89. if self.if_cls_token:
  90. if self.use_double_cls_token:
  91. cls_token_head = self.cls_token_head.expand(B, -1, -1)
  92. cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
  93. token_position = [0, M + 1]
  94. x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
  95. M = x.shape[1]
  96. else:
  97. if self.use_middle_cls_token:
  98. cls_token = self.cls_token.expand(B, -1, -1)
  99. token_position = M // 2
  100. # add cls token in the middle
  101. x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
  102. elif if_random_cls_token_position:
  103. cls_token = self.cls_token.expand(B, -1, -1)
  104. token_position = random.randint(0, M)
  105. x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
  106. print("token_position: ", token_position)
  107. else:
  108. cls_token = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
  109. token_position = 0
  110. x = torch.cat((cls_token, x), dim=1)
  111. M = x.shape[1]
  112. if self.if_abs_pos_embed:
  113. # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]:
  114. # x = x + self.pos_embed
  115. # else:
  116. # pos_embed = interpolate_pos_embed_online(
  117. # self.pos_embed, self.patch_embed.grid_size, new_grid_size,0
  118. # )
  119. x = x + self.pos_embed
  120. x = self.pos_drop(x)
  121. if if_random_token_rank:
  122. # 生成随机 shuffle 索引
  123. shuffle_indices = torch.randperm(M)
  124. if isinstance(token_position, list):
  125. print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
  126. else:
  127. print("original value: ", x[0, token_position, 0])
  128. print("original token_position: ", token_position)
  129. # 执行 shuffle
  130. x = x[:, shuffle_indices, :]
  131. if isinstance(token_position, list):
  132. # 找到 cls token 在 shuffle 之后的新位置
  133. new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))]
  134. token_position = new_token_position
  135. else:
  136. # 找到 cls token 在 shuffle 之后的新位置
  137. token_position = torch.where(shuffle_indices == token_position)[0].item()
  138. if isinstance(token_position, list):
  139. print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
  140. else:
  141. print("new value: ", x[0, token_position, 0])
  142. print("new token_position: ", token_position)
  143. if_flip_img_sequences = False
  144. if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
  145. x = x.flip([1])
  146. if_flip_img_sequences = True
  147. # mamba impl
  148. residual = None
  149. hidden_states = x
  150. if not self.if_bidirectional:
  151. for layer in self.layers:
  152. if if_flip_img_sequences and self.if_rope:
  153. hidden_states = hidden_states.flip([1])
  154. if residual is not None:
  155. residual = residual.flip([1])
  156. # rope about
  157. if self.if_rope:
  158. hidden_states = self.rope(hidden_states)
  159. if residual is not None and self.if_rope_residual:
  160. residual = self.rope(residual)
  161. if if_flip_img_sequences and self.if_rope:
  162. hidden_states = hidden_states.flip([1])
  163. if residual is not None:
  164. residual = residual.flip([1])
  165. hidden_states, residual = layer(
  166. hidden_states, residual, inference_params=inference_params
  167. )
  168. else:
  169. # get two layers in a single for-loop
  170. for i in range(len(self.layers) // 2):
  171. if self.if_rope:
  172. hidden_states = self.rope(hidden_states)
  173. if residual is not None and self.if_rope_residual:
  174. residual = self.rope(residual)
  175. hidden_states_f, residual_f = self.layers[i * 2](
  176. hidden_states, residual, inference_params=inference_params
  177. )
  178. hidden_states_b, residual_b = self.layers[i * 2 + 1](
  179. hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params
  180. )
  181. hidden_states = hidden_states_f + hidden_states_b.flip([1])
  182. residual = residual_f + residual_b.flip([1])
  183. if not self.fused_add_norm:
  184. if residual is None:
  185. residual = hidden_states
  186. else:
  187. residual = residual + self.drop_path(hidden_states)
  188. hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
  189. else:
  190. # Set prenorm=False here since we don't need the residual
  191. fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
  192. hidden_states = fused_add_norm_fn(
  193. self.drop_path(hidden_states),
  194. self.norm_f.weight,
  195. self.norm_f.bias,
  196. eps=self.norm_f.eps,
  197. residual=residual,
  198. prenorm=False,
  199. residual_in_fp32=self.residual_in_fp32,
  200. )
  201. return hidden_states, token_position
  202. # return only cls token if it exists
  203. if self.if_cls_token:
  204. if self.use_double_cls_token:
  205. return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
  206. else:
  207. if self.use_middle_cls_token:
  208. return hidden_states[:, token_position, :]
  209. elif if_random_cls_token_position:
  210. return hidden_states[:, token_position, :]
  211. else:
  212. return hidden_states[:, token_position, :]
  213. if self.final_pool_type == 'none':
  214. return hidden_states[:, -1, :]
  215. elif self.final_pool_type == 'mean':
  216. return hidden_states.mean(dim=1)
  217. elif self.final_pool_type == 'max':
  218. return hidden_states
  219. elif self.final_pool_type == 'all':
  220. return hidden_states
  221. else:
  222. raise NotImplementedError
  223. # modified from https://github.com/hustvl/Vim/blob/main/vim/models_mamba.py#VisionMamba
  224. def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
  225. hs, token_position = forward_features(self, x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
  226. print("self.if_cls_token", self.if_cls_token, end=" ")
  227. print("self.use_double_cls_token", self.use_double_cls_token, end=" ")
  228. print("self.use_middle_cls_token", self.use_middle_cls_token, end=" ")
  229. print("if_random_cls_token_position", if_random_cls_token_position, end=" ")
  230. print("if_random_token_rank", if_random_token_rank, end=" ")
  231. indexes = list(range(hs.shape[1]))
  232. token_position = token_position if isinstance(token_position, list) else [token_position]
  233. for t in token_position:
  234. indexes.remove(t)
  235. hs = hs[:, indexes, :].contiguous()
  236. H = int(math.sqrt(hs.shape[1]))
  237. hs = hs.permute(0, 2, 1).contiguous().view(hs.shape[0], -1, H, H)
  238. return hs
  239. model.forward = partial(forward, model)
  240. elif remove_head:
  241. model.forward = partial(model.forward, return_features=True)
  242. model = model.cuda().eval()
  243. if with_ckpt:
  244. ckpt = torch.load(open(f"{HOME}/packs/ckpts/vim_s_midclstok_80p5acc.pth", "rb"), map_location=torch.device("cpu"))["model"]
  245. # to interplate pos_mebed, the cls_token position must be fixed !
  246. # otherwise, ignore cls_token and apply interplation to all
  247. # this checkpoint uses middle cls token
  248. # from mmpretrain.models.backbones.vision_transformer import resize_pos_embed
  249. assert not model.use_double_cls_token
  250. assert model.use_middle_cls_token
  251. assert ckpt["pos_embed"].shape[1] == 197
  252. target_token_length = (img_size // 16)**2
  253. target_token_length_HW = ((img_size // 16), (img_size // 16))
  254. if target_token_length != 197 - 1:
  255. mid_token_idx = target_token_length // 2
  256. cls_token = ckpt["pos_embed"][:, 83:84, :]
  257. extra_tokens_left = ckpt["pos_embed"][:, :83, :]
  258. extra_tokens_right = ckpt["pos_embed"][:, 84:, :]
  259. extra_tokens = torch.cat([extra_tokens_left, extra_tokens_right], dim=1)
  260. extra_tokens = extra_tokens.reshape(1, 14, 14, -1).permute(0, 3, 1, 2)
  261. extra_tokens = torch.nn.functional.interpolate(extra_tokens, size=target_token_length_HW, align_corners=False, mode="bicubic")
  262. extra_tokens = extra_tokens.permute(0, 2, 3, 1).contiguous().view(1, target_token_length, -1)
  263. pos_embed = torch.cat([extra_tokens[:, :mid_token_idx, :], cls_token, extra_tokens[:, mid_token_idx:, :]], dim=1)
  264. ckpt["pos_embed"] = pos_embed
  265. model.load_state_dict(ckpt)
  266. return model
  267. # 5.301500928 25796584
  268. def flops_vim(size=224):
  269. from fvcore.nn import flop_count
  270. # FLOPs.fvcore_flop_count(BuildModels.build_vmamba(scale="tv2").cuda().eval(), input_shape=(3, size, size), show_arch=False)
  271. specpath = f"{HOME}/packs/Vim/mamba-1p1p1"
  272. sys.path.insert(0, specpath)
  273. import mamba_ssm
  274. _model = import_abspy("models_mamba", f"{HOME}/packs/Vim/vim")
  275. sys.path = sys.path[1:]
  276. # model = _model.vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2()
  277. kwargs=dict()
  278. # model = _model.VisionMamba(patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  279. # fused add norm share the same flops as naive one
  280. model = _model.VisionMamba(img_size=size, patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=False, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  281. vims = model.cuda().eval()
  282. # RMSNorm share the same flops as naive one
  283. # https://github.com/state-spaces/mamba/blob/v1.2.2/mamba_ssm/ops/triton/layernorm.py
  284. def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False):
  285. dtype = x.dtype
  286. if upcast:
  287. weight = weight.float()
  288. bias = bias.float() if bias is not None else None
  289. if upcast:
  290. x = x.float()
  291. residual = residual.float() if residual is not None else residual
  292. if residual is not None:
  293. x = (x + residual).to(x.dtype)
  294. rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
  295. out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
  296. out = out.to(dtype)
  297. return out if not prenorm else (out, x)
  298. def rms_forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
  299. return rms_norm_ref(
  300. x,
  301. self.weight,
  302. self.bias,
  303. residual=residual,
  304. eps=self.eps,
  305. prenorm=prenorm,
  306. upcast=residual_in_fp32,
  307. )
  308. for k, m in vims.named_modules():
  309. if isinstance(m, _model.RMSNorm):
  310. m.forward = partial(rms_forward, m)
  311. input_shape = (1, 3, size, size)
  312. model = vims.cuda().eval()
  313. import math
  314. def causal_conv_1d_jit(inputs, outputs):
  315. """
  316. https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
  317. x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) out: (batch, dim, seqlen)
  318. out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
  319. """
  320. from fvcore.nn.jit_handles import conv_flop_jit
  321. return conv_flop_jit(inputs, outputs)
  322. # ONLY FOR VisionMamba
  323. def MambaInnerFnNoOutProj_jit(inputs, outputs):
  324. """
  325. conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
  326. x, z = xz.chunk(2, dim=1)
  327. conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias,None, True)
  328. x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
  329. delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
  330. B = x_dbl[:, delta_rank:delta_rank + d_state]
  331. B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  332. C = x_dbl[:, -d_state:]
  333. C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  334. out, scan_intermediates, out_z = selective_scan_cuda.fwd(
  335. conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
  336. )
  337. """
  338. xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A = inputs[0:6]
  339. Batch, _, L = xz.type().sizes()
  340. CWidth = conv1d_weight.type().sizes()[-1]
  341. H = A.type().sizes()[-1] # 16
  342. Dim, R = delta_proj_weight.type().sizes()
  343. assert tuple(xz.type().sizes()) == (Batch, 2 * Dim, L)
  344. assert tuple(conv1d_weight.type().sizes()) == (Dim, 1, CWidth)
  345. assert tuple(x_proj_weight.type().sizes()) == (R + H + H, Dim)
  346. assert tuple(A.type().sizes()) == (Dim, H)
  347. with_Z = True
  348. with_D = False
  349. if "D" in inputs[6].debugName():
  350. assert tuple(inputs[6].type().sizes()) == (Dim,)
  351. with_D = True
  352. flops = 0
  353. flops += Batch * (Dim * L) * CWidth # causal_conv1d_cuda.causal_conv1d_fwd
  354. flops += Batch * (Dim * L) * (R + H + H) # x_dbl = F.linear(...
  355. flops += Batch * (Dim * R) * (L) # delta_proj_weight @ x_dbl[:, :delta_rank]
  356. # https://github.com/state-spaces/mamba/issues/110
  357. flops = 9 * Batch * L * Dim * H
  358. if with_D:
  359. flops += Batch * Dim * L
  360. if with_Z:
  361. flops += Batch * Dim * L
  362. return flops
  363. # ONLY FOR Mamba
  364. def MambaInnerFn_jit(inputs, outputs):
  365. """
  366. conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
  367. x, z = xz.chunk(2, dim=1)
  368. conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias,None, True)
  369. x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
  370. delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
  371. B = x_dbl[:, delta_rank:delta_rank + d_state]
  372. B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  373. C = x_dbl[:, -d_state:]
  374. C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  375. out, scan_intermediates, out_z = selective_scan_cuda.fwd(
  376. conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
  377. )
  378. F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
  379. """
  380. xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, out_proj_bias, A = inputs[0:8]
  381. Batch, _, L = xz.type().sizes()
  382. CWidth = conv1d_weight.type().sizes()[-1]
  383. H = A.type().sizes()[-1] # 16
  384. Dim, R = delta_proj_weight.type().sizes()
  385. assert tuple(xz.type().sizes()) == (Batch, 2 * Dim, L)
  386. assert tuple(conv1d_weight.type().sizes()) == (Dim, 1, CWidth)
  387. assert tuple(x_proj_weight.type().sizes()) == (R + H + H, Dim)
  388. assert tuple(A.type().sizes()) == (Dim, H)
  389. with_Z = True
  390. with_D = False
  391. if "D" in inputs[6].debugName():
  392. assert tuple(inputs[6].type().sizes()) == (Dim,)
  393. with_D = True
  394. flops = 0
  395. flops += Batch * (Dim * L) * CWidth # causal_conv1d_cuda.causal_conv1d_fwd
  396. flops += Batch * (Dim * L) * (R + H + H) # x_dbl = F.linear(...
  397. flops += Batch * (Dim * R) * (L) # delta_proj_weight @ x_dbl[:, :delta_rank]
  398. # https://github.com/state-spaces/mamba/issues/110
  399. flops = 9 * Batch * L * Dim * H
  400. if with_D:
  401. flops += Batch * Dim * L
  402. if with_Z:
  403. flops += Batch * Dim * L
  404. out_weight_shape = out_proj_weight.type().sizes()
  405. assert out_proj_weight[1] == Dim
  406. flops += Batch * Dim * L * out_proj_weight[0]
  407. return flops
  408. supported_ops={
  409. "aten::gelu": None, # as relu is in _IGNORED_OPS
  410. "aten::silu": None, # as relu is in _IGNORED_OPS
  411. "aten::neg": None, # as relu is in _IGNORED_OPS
  412. "aten::exp": None, # as relu is in _IGNORED_OPS
  413. "aten::flip": None, # as permute is in _IGNORED_OPS
  414. "prim::PythonOp.CausalConv1dFn": causal_conv_1d_jit,
  415. "prim::PythonOp.MambaInnerFnNoOutProj": MambaInnerFnNoOutProj_jit,
  416. "prim::PythonOp.MambaInnerFn": MambaInnerFn_jit,
  417. }
  418. inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
  419. Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=supported_ops)
  420. # print(Gflops.items())
  421. print("GFlops: ", sum(Gflops.values()), "Params: ", sum([p.numel() for _, p in model.named_parameters()]), flush=True)
  422. def erf_vim(data_path = "/media/Disk1/Dataset/ImageNet_ILSVRC2012"):
  423. print("vim ================================", flush=True)
  424. specpath = f"{HOME}/packs/Vim/mamba-1p1p1"
  425. sys.path.insert(0, specpath)
  426. import mamba_ssm
  427. _model = import_abspy("models_mamba", f"{HOME}/packs/Vim/vim")
  428. sys.path = sys.path[1:]
  429. # model = _model.vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2()
  430. kwargs=dict()
  431. # model = _model.VisionMamba(patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  432. model = _model.VisionMamba(img_size=1024, patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  433. # copy from https://github.com/hustvl/Vim/blob/main/vim/models_mamba.py#VisionMamba
  434. # added "return hidden_states, token_position"
  435. RMSNorm, layer_norm_fn, rms_norm_fn = _model.RMSNorm, _model.layer_norm_fn, _model.rms_norm_fn
  436. def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
  437. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  438. # with slight modifications to add the dist_token
  439. x = self.patch_embed(x)
  440. B, M, _ = x.shape
  441. if self.if_cls_token:
  442. if self.use_double_cls_token:
  443. cls_token_head = self.cls_token_head.expand(B, -1, -1)
  444. cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
  445. token_position = [0, M + 1]
  446. x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
  447. M = x.shape[1]
  448. else:
  449. if self.use_middle_cls_token:
  450. cls_token = self.cls_token.expand(B, -1, -1)
  451. token_position = M // 2
  452. # add cls token in the middle
  453. x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
  454. elif if_random_cls_token_position:
  455. cls_token = self.cls_token.expand(B, -1, -1)
  456. token_position = random.randint(0, M)
  457. x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
  458. print("token_position: ", token_position)
  459. else:
  460. cls_token = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
  461. token_position = 0
  462. x = torch.cat((cls_token, x), dim=1)
  463. M = x.shape[1]
  464. if self.if_abs_pos_embed:
  465. # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]:
  466. # x = x + self.pos_embed
  467. # else:
  468. # pos_embed = interpolate_pos_embed_online(
  469. # self.pos_embed, self.patch_embed.grid_size, new_grid_size,0
  470. # )
  471. x = x + self.pos_embed
  472. x = self.pos_drop(x)
  473. if if_random_token_rank:
  474. # 生成随机 shuffle 索引
  475. shuffle_indices = torch.randperm(M)
  476. if isinstance(token_position, list):
  477. print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
  478. else:
  479. print("original value: ", x[0, token_position, 0])
  480. print("original token_position: ", token_position)
  481. # 执行 shuffle
  482. x = x[:, shuffle_indices, :]
  483. if isinstance(token_position, list):
  484. # 找到 cls token 在 shuffle 之后的新位置
  485. new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))]
  486. token_position = new_token_position
  487. else:
  488. # 找到 cls token 在 shuffle 之后的新位置
  489. token_position = torch.where(shuffle_indices == token_position)[0].item()
  490. if isinstance(token_position, list):
  491. print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
  492. else:
  493. print("new value: ", x[0, token_position, 0])
  494. print("new token_position: ", token_position)
  495. if_flip_img_sequences = False
  496. if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
  497. x = x.flip([1])
  498. if_flip_img_sequences = True
  499. # mamba impl
  500. residual = None
  501. hidden_states = x
  502. if not self.if_bidirectional:
  503. for layer in self.layers:
  504. if if_flip_img_sequences and self.if_rope:
  505. hidden_states = hidden_states.flip([1])
  506. if residual is not None:
  507. residual = residual.flip([1])
  508. # rope about
  509. if self.if_rope:
  510. hidden_states = self.rope(hidden_states)
  511. if residual is not None and self.if_rope_residual:
  512. residual = self.rope(residual)
  513. if if_flip_img_sequences and self.if_rope:
  514. hidden_states = hidden_states.flip([1])
  515. if residual is not None:
  516. residual = residual.flip([1])
  517. hidden_states, residual = layer(
  518. hidden_states, residual, inference_params=inference_params
  519. )
  520. else:
  521. # get two layers in a single for-loop
  522. for i in range(len(self.layers) // 2):
  523. if self.if_rope:
  524. hidden_states = self.rope(hidden_states)
  525. if residual is not None and self.if_rope_residual:
  526. residual = self.rope(residual)
  527. hidden_states_f, residual_f = self.layers[i * 2](
  528. hidden_states, residual, inference_params=inference_params
  529. )
  530. hidden_states_b, residual_b = self.layers[i * 2 + 1](
  531. hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params
  532. )
  533. hidden_states = hidden_states_f + hidden_states_b.flip([1])
  534. residual = residual_f + residual_b.flip([1])
  535. if not self.fused_add_norm:
  536. if residual is None:
  537. residual = hidden_states
  538. else:
  539. residual = residual + self.drop_path(hidden_states)
  540. hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
  541. else:
  542. # Set prenorm=False here since we don't need the residual
  543. fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
  544. hidden_states = fused_add_norm_fn(
  545. self.drop_path(hidden_states),
  546. self.norm_f.weight,
  547. self.norm_f.bias,
  548. eps=self.norm_f.eps,
  549. residual=residual,
  550. prenorm=False,
  551. residual_in_fp32=self.residual_in_fp32,
  552. )
  553. return hidden_states, token_position
  554. # return only cls token if it exists
  555. if self.if_cls_token:
  556. if self.use_double_cls_token:
  557. return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
  558. else:
  559. if self.use_middle_cls_token:
  560. return hidden_states[:, token_position, :]
  561. elif if_random_cls_token_position:
  562. return hidden_states[:, token_position, :]
  563. else:
  564. return hidden_states[:, token_position, :]
  565. if self.final_pool_type == 'none':
  566. return hidden_states[:, -1, :]
  567. elif self.final_pool_type == 'mean':
  568. return hidden_states.mean(dim=1)
  569. elif self.final_pool_type == 'max':
  570. return hidden_states
  571. elif self.final_pool_type == 'all':
  572. return hidden_states
  573. else:
  574. raise NotImplementedError
  575. # modified from https://github.com/hustvl/Vim/blob/main/vim/models_mamba.py#VisionMamba
  576. def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
  577. hs, token_position = forward_features(self, x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
  578. print("self.if_cls_token", self.if_cls_token, end=" ")
  579. print("self.use_double_cls_token", self.use_double_cls_token, end=" ")
  580. print("self.use_middle_cls_token", self.use_middle_cls_token, end=" ")
  581. print("if_random_cls_token_position", if_random_cls_token_position, end=" ")
  582. print("if_random_token_rank", if_random_token_rank, end=" ")
  583. indexes = list(range(hs.shape[1]))
  584. token_position = token_position if isinstance(token_position, list) else [token_position]
  585. for t in token_position:
  586. indexes.remove(t)
  587. hs = hs[:, indexes, :].contiguous()
  588. H = int(math.sqrt(hs.shape[1]))
  589. hs = hs.permute(0, 2, 1).contiguous().view(hs.shape[0], -1, H, H)
  590. return hs
  591. model.forward = partial(forward, model)
  592. vims = model.cuda().eval()
  593. model_before = EffectiveReceiptiveField.get_input_grad_avg(vims, size=1024, data_path=data_path, norms=EffectiveReceiptiveField.simpnorm)
  594. # with ckpt
  595. ckpt = torch.load(open(f"{HOME}/packs/ckpts/vim_s_midclstok_80p5acc.pth", "rb"), map_location=torch.device("cpu"))["model"]
  596. # to interplate pos_mebed, the cls_token position must be fixed !
  597. # otherwise, ignore cls_token and apply interplation to all
  598. # this checkpoint uses middle cls token
  599. from mmpretrain.models.backbones.vision_transformer import resize_pos_embed, to_2tuple, np
  600. assert not vims.use_double_cls_token
  601. assert vims.use_middle_cls_token
  602. assert ckpt["pos_embed"].shape[1] == 197
  603. cls_token = ckpt["pos_embed"][:, 83:84, :]
  604. extra_tokens_left = ckpt["pos_embed"][:, :83, :]
  605. extra_tokens_right = ckpt["pos_embed"][:, 84:, :]
  606. extra_tokens = torch.cat([extra_tokens_left, extra_tokens_right], dim=1)
  607. extra_tokens = extra_tokens.reshape(1, 14, 14, -1).permute(0, 3, 1, 2)
  608. extra_tokens = torch.nn.functional.interpolate(extra_tokens, size=(64, 64), align_corners=False, mode="bicubic")
  609. extra_tokens = extra_tokens.permute(0, 2, 3, 1).contiguous().view(1, 4096, -1)
  610. pos_embed = torch.cat([extra_tokens[:, :2048, :], cls_token, extra_tokens[:, 2048:, :]], dim=1)
  611. ckpt["pos_embed"] = pos_embed
  612. model.load_state_dict(ckpt)
  613. vims = model.cuda().eval()
  614. model_after = EffectiveReceiptiveField.get_input_grad_avg(vims, size=1024, data_path=data_path, norms=EffectiveReceiptiveField.simpnorm)
  615. return model_before, model_after
  616. if __name__ == "__main__":
  617. showpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./show").rstrip("/")
  618. data_path = "/media/Disk1/Dataset/ImageNet_ILSVRC2012"
  619. ExtraDev.flops_vim()
  620. ExtraDev.flops_s4nd()
  621. vim_before, vim_after = ExtraDev.erf_vim()
  622. visualize.visualize_snsmaps([(vim_before, ""), (vim_after, "")], savefig=f"{showpath}/erf_s4ndmethods.jpg", rows=2, sticks=False, figsize=(10, 10.75), cmap='RdYlGn')