# mobilemamba.py (24 KB)

  1. import torch
  2. import itertools
  3. import torch.nn as nn
  4. from timm.models.vision_transformer import trunc_normal_
  5. from timm.models.layers import SqueezeExcite
  6. from model import MODEL
  7. from model.lib_mamba.vmambanew import SS2D
  8. import torch.nn.functional as F
  9. from functools import partial
  10. import pywt
  11. import pywt.data
  12. from timm.layers import DropPath
  def create_wavelet_filter(wave, in_size, out_size, type=torch.float):
      """Build the 2-D separable analysis/synthesis filter banks for a pywt wavelet.

      Each bank holds four 2-D kernels formed as outer products of the wavelet's 1-D
      low-pass (lo) and high-pass (hi) taps, stacked in (row, col) order
      (lo, lo), (lo, hi), (hi, lo), (hi, hi).

      Args:
          wave: wavelet name accepted by ``pywt.Wavelet`` (e.g. ``'db1'``).
          in_size: number of times the analysis bank is tiled along dim 0.
          out_size: number of times the synthesis bank is tiled along dim 0.
          type: dtype of the returned tensors.

      Returns:
          ``(dec_filters, rec_filters)`` of shapes ``(4 * in_size, 1, k, k)`` and
          ``(4 * out_size, 1, k, k)``, with ``k = len(pywt.Wavelet(wave).dec_lo)``.
      """
      wavelet = pywt.Wavelet(wave)

      def _bank(lo, hi):
          # Kernel entry [i, j] is col_tap[i] * row_tap[j] for each (row, col) pair.
          pairs = ((lo, lo), (lo, hi), (hi, lo), (hi, hi))
          return torch.stack([col[:, None] * row[None, :] for row, col in pairs], dim=0)

      # Analysis taps are time-reversed (presumably because F.conv2d computes a correlation).
      dec_lo = torch.tensor(wavelet.dec_lo[::-1], dtype=type)
      dec_hi = torch.tensor(wavelet.dec_hi[::-1], dtype=type)
      dec_filters = _bank(dec_lo, dec_hi)[:, None].repeat(in_size, 1, 1, 1)

      # Reversing and then flipping cancels out, so the synthesis taps are used as stored.
      rec_lo = torch.tensor(wavelet.rec_lo[::-1], dtype=type).flip(dims=[0])
      rec_hi = torch.tensor(wavelet.rec_hi[::-1], dtype=type).flip(dims=[0])
      rec_filters = _bank(rec_lo, rec_hi)[:, None].repeat(out_size, 1, 1, 1)

      return dec_filters, rec_filters
  def wavelet_transform(x, filters):
      """Apply a one-level 2-D wavelet analysis to every channel independently.

      Args:
          x: tensor of shape ``(b, c, h, w)``. ``h`` and ``w`` are assumed even;
              ``MBWTConv2d.forward`` pads odd sizes before calling this.
          filters: analysis bank of shape ``(4 * c, 1, kh, kw)``, e.g. from
              ``create_wavelet_filter``.

      Returns:
          Tensor of shape ``(b, c, 4, h // 2, w // 2)``: four sub-bands per channel.
      """
      batch, channels, height, width = x.shape
      kernel_h, kernel_w = filters.shape[2], filters.shape[3]
      padding = (kernel_h // 2 - 1, kernel_w // 2 - 1)
      # groups=channels: each input channel is filtered by its own block of four kernels.
      bands = F.conv2d(x, filters, stride=2, groups=channels, padding=padding)
      return bands.reshape(batch, channels, 4, height // 2, width // 2)
  def inverse_wavelet_transform(x, filters):
      """Merge four sub-bands per channel back into one spatial map (inverse of ``wavelet_transform``).

      Args:
          x: tensor of shape ``(b, c, 4, h_half, w_half)``.
          filters: synthesis bank of shape ``(4 * c, 1, kh, kw)``.

      Returns:
          Tensor of shape ``(b, c, H, W)``; for even kernel sizes ``H = 2 * h_half`` and
          ``W = 2 * w_half``.
      """
      batch, channels, _, half_h, half_w = x.shape
      padding = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1)
      # Fold the sub-band axis into channels so each group of four bands maps to one output channel.
      stacked = x.reshape(batch, channels * 4, half_h, half_w)
      return F.conv_transpose2d(stacked, filters, stride=2, groups=channels, padding=padding)
  class MBWTConv2d(nn.Module):
      """Wavelet-domain depthwise convolution added to an ``SS2D`` branch.

      The input is decomposed with a ``wt_levels``-level 2-D wavelet transform. Each level's
      four sub-bands pass through a depthwise conv and a learnable per-channel scale, and
      the levels are recombined with the inverse transform. That reconstruction is added to
      ``base_scale(global_atten(x))``, and an optional stride is applied at the end.

      Args:
          in_channels: number of input channels.
          out_channels: must equal ``in_channels`` (asserted).
          kernel_size: kernel size of the per-level depthwise convs.
          stride: subsampling factor applied after the sum (1 = none).
          bias: accepted for interface compatibility; not used.
          wt_levels: number of wavelet decomposition levels.
          wt_type: wavelet name passed to ``create_wavelet_filter``.
          ssm_ratio: forwarded to ``SS2D``.
          forward_type: forwarded to ``SS2D``.
      """

      def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, bias=True, wt_levels=1, wt_type='db1',
                   ssm_ratio=1, forward_type="v05", ):
          super(MBWTConv2d, self).__init__()
          assert in_channels == out_channels
          self.in_channels = in_channels
          self.wt_levels = wt_levels
          self.stride = stride
          self.dilation = 1
          wt_filter, iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
          # Fixed filters are registered as (frozen) parameters so they follow .to()/.cuda()
          # and are stored in the state_dict.
          self.wt_filter = nn.Parameter(wt_filter, requires_grad=False)
          self.iwt_filter = nn.Parameter(iwt_filter, requires_grad=False)
          self.global_atten = SS2D(d_model=in_channels, d_state=1,
                                   ssm_ratio=ssm_ratio, initialize="v2", forward_type=forward_type, channel_first=True,
                                   k_group=2)
          self.base_scale = _ScaleModule([1, in_channels, 1, 1])
          self.wavelet_convs = nn.ModuleList(
              [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
                         groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
          )
          self.wavelet_scale = nn.ModuleList(
              [_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
          )
          if self.stride > 1:
              self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
              # A bound method, unlike a lambda closure, is rebound correctly by copy.deepcopy
              # and can be pickled.
              self.do_stride = self._apply_stride
          else:
              self.do_stride = None

      def wt_function(self, x):
          """Forward wavelet transform using the module's current analysis filters."""
          # Read the parameter at call time instead of capturing it in a partial at init,
          # so replacing or converting the parameter can never leave a stale reference.
          return wavelet_transform(x, self.wt_filter)

      def iwt_function(self, x):
          """Inverse wavelet transform using the module's current synthesis filters."""
          return inverse_wavelet_transform(x, self.iwt_filter)

      def _apply_stride(self, x_in):
          """Subsample with a frozen all-ones 1x1 depthwise conv of stride ``self.stride``."""
          return F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,
                          groups=self.in_channels)

      def forward(self, x):
          x_ll_in_levels = []
          x_h_in_levels = []
          shapes_in_levels = []
          curr_x_ll = x
          for i in range(self.wt_levels):
              curr_shape = curr_x_ll.shape
              shapes_in_levels.append(curr_shape)
              if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                  # Pad one row/column at the bottom/right so the stride-2 transform sees even sizes.
                  curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                  curr_x_ll = F.pad(curr_x_ll, curr_pads)
              curr_x = self.wt_function(curr_x_ll)
              # The un-convolved low-pass band feeds the next (coarser) level.
              curr_x_ll = curr_x[:, :, 0, :, :]
              shape_x = curr_x.shape
              curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
              curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
              curr_x_tag = curr_x_tag.reshape(shape_x)
              x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
              x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])

          # Reconstruct from the coarsest level upward. Each level's output is added to the
          # next finer level's low-pass band.
          next_x_ll = 0
          for _ in range(self.wt_levels - 1, -1, -1):
              curr_x_ll = x_ll_in_levels.pop()
              curr_x_h = x_h_in_levels.pop()
              curr_shape = shapes_in_levels.pop()
              curr_x_ll = curr_x_ll + next_x_ll
              curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
              next_x_ll = self.iwt_function(curr_x)
              # Crop away any padding added during decomposition.
              next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
          # Assigned after the loop so wt_levels=0 yields a zero wavelet contribution.
          x_tag = next_x_ll
          assert len(x_ll_in_levels) == 0

          x = self.base_scale(self.global_atten(x))
          x = x + x_tag
          if self.do_stride is not None:
              x = self.do_stride(x)
          return x
  108. class _ScaleModule(nn.Module):
  109. def __init__(self, dims, init_scale=1.0, init_bias=0):
  110. super(_ScaleModule, self).__init__()
  111. self.dims = dims
  112. self.weight = nn.Parameter(torch.ones(*dims) * init_scale)
  113. self.bias = None
  114. def forward(self, x):
  115. return torch.mul(self.weight, x)
  class DWConv2d_BN_ReLU(nn.Sequential):
      """Depthwise conv -> BN -> ReLU -> grouped 1x1 conv -> BN.

      Args:
          in_channels: input channels; also the ``groups`` of both convs.
          out_channels: output channels of the 1x1 conv. It must be divisible by
              ``in_channels`` because that conv uses ``groups=in_channels``.
          kernel_size: spatial size of the first depthwise conv (padding ``kernel_size // 2``).
          bn_weight_init: initial gamma of both BatchNorms (betas start at 0).
      """

      def __init__(self, in_channels, out_channels, kernel_size=3, bn_weight_init=1):
          super().__init__()
          self.add_module('dwconv3x3',
                          nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=1,
                                    padding=kernel_size // 2, groups=in_channels, bias=False))
          self.add_module('bn1', nn.BatchNorm2d(in_channels))
          self.add_module('relu', nn.ReLU(inplace=True))
          self.add_module('dwconv1x1',
                          nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0,
                                    groups=in_channels, bias=False))
          self.add_module('bn2', nn.BatchNorm2d(out_channels))
          # Initialize batch norm weights
          nn.init.constant_(self.bn1.weight, bn_weight_init)
          nn.init.constant_(self.bn1.bias, 0)
          nn.init.constant_(self.bn2.weight, bn_weight_init)
          nn.init.constant_(self.bn2.bias, 0)

      @staticmethod
      def _fuse_conv_bn(conv, bn):
          """Return a biased Conv2d equal to ``bn(conv(x))`` under BN's running statistics."""
          scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
          weight = conv.weight * scale[:, None, None, None]
          bias = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
          # Match both device and dtype of the source conv. With dtype omitted, a half/double
          # model would get float32 layers and fail with a dtype mismatch at inference.
          fused = nn.Conv2d(weight.size(1) * conv.groups, weight.size(0), weight.shape[2:], stride=conv.stride,
                            padding=conv.padding, dilation=conv.dilation, groups=conv.groups,
                            device=conv.weight.device, dtype=conv.weight.dtype)
          fused.weight.data.copy_(weight)
          fused.bias.data.copy_(bias)
          return fused

      @torch.no_grad()
      def fuse(self):
          """Fold both BatchNorms into their preceding convs, using running statistics.

          Returns:
              ``nn.Sequential(fused_dwconv3x3, relu, fused_dwconv1x1)``. The ReLU module
              instance is shared with this block.
          """
          dwconv3x3, bn1, relu, dwconv1x1, bn2 = self._modules.values()
          fused_dwconv3x3 = self._fuse_conv_bn(dwconv3x3, bn1)
          fused_dwconv1x1 = self._fuse_conv_bn(dwconv1x1, bn2)
          return nn.Sequential(fused_dwconv3x3, relu, fused_dwconv1x1)
  class Conv2d_BN(torch.nn.Sequential):
      """Bias-free Conv2d followed by BatchNorm2d, foldable into one Conv2d for inference.

      Args:
          a: input channels.
          b: output channels.
          ks, stride, pad, dilation, groups: passed positionally to ``torch.nn.Conv2d``.
          bn_weight_init: initial BN gamma (beta starts at 0). With 0, the block initially
              outputs zeros.
      """

      def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                   groups=1, bn_weight_init=1, ):
          super().__init__()
          self.add_module('c', torch.nn.Conv2d(
              a, b, ks, stride, pad, dilation, groups, bias=False))
          self.add_module('bn', torch.nn.BatchNorm2d(b))
          torch.nn.init.constant_(self.bn.weight, bn_weight_init)
          torch.nn.init.constant_(self.bn.bias, 0)

      @torch.no_grad()
      def fuse(self):
          """Return one biased ``Conv2d`` equal to this block under BN running statistics.

          The fused conv is created on the same device and with the same dtype as the
          original conv weight. This lets a CUDA or half-precision model be fused in place
          without leaving a CPU/float32 layer behind.
          """
          c, bn = self._modules.values()
          w = bn.weight / (bn.running_var + bn.eps) ** 0.5
          w = c.weight * w[:, None, None, None]
          b = bn.bias - bn.running_mean * bn.weight / \
              (bn.running_var + bn.eps) ** 0.5
          m = torch.nn.Conv2d(w.size(1) * c.groups, w.size(0), w.shape[2:], stride=c.stride,
                              padding=c.padding, dilation=c.dilation, groups=c.groups,
                              device=c.weight.device, dtype=c.weight.dtype)
          m.weight.data.copy_(w)
          m.bias.data.copy_(b)
          return m
  class BN_Linear(torch.nn.Sequential):
      """BatchNorm1d followed by Linear, foldable into one Linear for inference.

      Args:
          a: input features.
          b: output features.
          bias: whether the Linear layer has a bias (initialised to 0).
          std: std of the truncated-normal init of the Linear weight.
      """

      def __init__(self, a, b, bias=True, std=0.02):
          super().__init__()
          self.add_module('bn', torch.nn.BatchNorm1d(a))
          self.add_module('l', torch.nn.Linear(a, b, bias=bias))
          trunc_normal_(self.l.weight, std=std)
          if bias:
              torch.nn.init.constant_(self.l.bias, 0)

      @torch.no_grad()
      def fuse(self):
          """Return one biased ``Linear`` equal to ``l(bn(x))`` under BN running statistics.

          The result always has a bias, because the BN shift becomes one. It is created on
          the same device and with the same dtype as the original Linear weight.
          """
          bn, l = self._modules.values()
          scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
          shift = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
          # l(bn(x)) = W (scale * x + shift) + bias = (W * scale) x + (W @ shift + bias)
          w = l.weight * scale[None, :]
          if l.bias is None:
              b = shift @ l.weight.T
          else:
              b = (l.weight @ shift[:, None]).view(-1) + l.bias
          m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device, dtype=l.weight.dtype)
          m.weight.data.copy_(w)
          m.bias.data.copy_(b)
          return m
  class PatchMerging(torch.nn.Module):
      """Inverted-bottleneck downsampling block.

      Steps: 1x1 expand to ``4 * dim`` -> ReLU -> 3x3 depthwise stride-2 conv -> ReLU ->
      squeeze-excite -> 1x1 projection to ``out_dim``. The same ReLU module is reused for
      both activations.
      """

      def __init__(self, dim, out_dim):
          super().__init__()
          expanded = int(dim * 4)
          self.conv1 = Conv2d_BN(dim, expanded, 1, 1, 0, )
          self.act = torch.nn.ReLU()
          self.conv2 = Conv2d_BN(expanded, expanded, 3, 2, 1, groups=expanded, )
          self.se = SqueezeExcite(expanded, .25)
          self.conv3 = Conv2d_BN(expanded, out_dim, 1, 1, 0, )

      def forward(self, x):
          hidden = self.act(self.conv1(x))
          hidden = self.act(self.conv2(hidden))
          hidden = self.se(hidden)
          return self.conv3(hidden)
  class Residual(torch.nn.Module):
      """Return ``x + m(x)``, optionally dropping the branch per sample during training.

      Args:
          m: the residual branch module.
          drop: while training, each sample's branch is zeroed with probability ``drop``.
              Kept branches are rescaled by ``1 / (1 - drop)``.
      """

      def __init__(self, m, drop=0.):
          super().__init__()
          self.m = m
          self.drop = drop

      def forward(self, x):
          branch = self.m(x)
          if not (self.training and self.drop > 0):
              return x + branch
          # Per-sample Bernoulli keep-mask, broadcast over (C, H, W).
          keep = torch.rand(x.size(0), 1, 1, 1,
                            device=x.device).ge_(self.drop).div(1 - self.drop).detach()
          return x + branch * keep
  class FFN(torch.nn.Module):
      """Pointwise feed-forward block: 1x1 Conv-BN (ed -> h), ReLU, 1x1 Conv-BN (h -> ed).

      The second Conv-BN starts with zero BN gamma, so the block initially outputs zeros.
      """

      def __init__(self, ed, h):
          super().__init__()
          self.pw1 = Conv2d_BN(ed, h)
          self.act = torch.nn.ReLU()
          self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0)

      def forward(self, x):
          hidden = self.pw1(x)
          hidden = self.act(hidden)
          return self.pw2(hidden)
  def nearest_multiple_of_16(n):
      """Round ``n`` to the nearest multiple of 16; a remainder of exactly 8 rounds up.

      >>> nearest_multiple_of_16(115), nearest_multiple_of_16(24)
      (112, 32)
      """
      remainder = n % 16
      floor_multiple = n - remainder
      return floor_multiple if remainder < 8 else floor_multiple + 16
  class MobileMambaModule(torch.nn.Module):
      """Split channels into global, local and identity groups, process them, then mix.

      Along dim 1:
      - the first ``global_channels`` go through ``MBWTConv2d``;
      - the next ``local_channels`` go through ``DWConv2d_BN_ReLU``;
      - the remaining ``identity_channels`` pass through unchanged.
      The concatenation is then mixed by ReLU + 1x1 Conv-BN with a zero-initialised BN gamma.

      Args:
          dim: total number of channels.
          global_ratio: fraction of channels for the global branch. It is rounded to the
              nearest multiple of 16 and capped at ``dim``.
          local_ratio: fraction of channels for the local branch. It is reduced to whatever
              remains if it would overflow ``dim``.
          kernels: kernel size for both branches.
          ssm_ratio, forward_type: forwarded to ``MBWTConv2d``.
      """

      def __init__(self, dim, global_ratio=0.25, local_ratio=0.25,
                   kernels=3, ssm_ratio=1, forward_type="v052d", ):
          super().__init__()
          self.dim = dim
          # Rounding to a multiple of 16 can overshoot small dims (e.g. dim=10, ratio=0.8 -> 16).
          # Cap it so the local and identity channel counts can never go negative.
          self.global_channels = min(nearest_multiple_of_16(int(global_ratio * dim)), dim)
          if self.global_channels + int(local_ratio * dim) > dim:
              self.local_channels = dim - self.global_channels
          else:
              self.local_channels = int(local_ratio * dim)
          self.identity_channels = self.dim - self.global_channels - self.local_channels
          if self.local_channels != 0:
              self.local_op = DWConv2d_BN_ReLU(self.local_channels, self.local_channels, kernels)
          else:
              self.local_op = nn.Identity()
          if self.global_channels != 0:
              self.global_op = MBWTConv2d(self.global_channels, self.global_channels, kernels, wt_levels=1,
                                          ssm_ratio=ssm_ratio, forward_type=forward_type, )
          else:
              self.global_op = nn.Identity()
          self.proj = torch.nn.Sequential(torch.nn.ReLU(), Conv2d_BN(
              dim, dim, bn_weight_init=0, ))

      def forward(self, x):  # x (B,C,H,W)
          x1, x2, x3 = torch.split(x, [self.global_channels, self.local_channels, self.identity_channels], dim=1)
          x1 = self.global_op(x1)
          x2 = self.local_op(x2)
          x = self.proj(torch.cat([x1, x2, x3], dim=1))
          return x
  class MobileMambaBlockWindow(torch.nn.Module):
      """Thin wrapper that applies a single ``MobileMambaModule`` (stored as ``attn``)."""

      def __init__(self, dim, global_ratio=0.25, local_ratio=0.25,
                   kernels=5, ssm_ratio=1, forward_type="v052d", ):
          super().__init__()
          self.dim = dim
          module_kwargs = dict(global_ratio=global_ratio, local_ratio=local_ratio, kernels=kernels,
                               ssm_ratio=ssm_ratio, forward_type=forward_type)
          self.attn = MobileMambaModule(dim, **module_kwargs)

      def forward(self, x):
          return self.attn(x)
  class MobileMambaBlock(torch.nn.Module):
      """dw-conv -> FFN -> MobileMamba mixer -> dw-conv -> FFN, each step wrapped in ``Residual``.

      When ``has_skip`` is true, the block input is added back through ``DropPath``.

      Args:
          type: mixer type; only ``'s'`` is implemented.
          ed: embedding (channel) dimension.
          global_ratio, local_ratio, kernels, ssm_ratio, forward_type: forwarded to the mixer.
          drop_path: DropPath rate for the outer skip (0 disables it).
          has_skip: whether to add the outer skip connection.

      Raises:
          ValueError: if ``type`` is not ``'s'``.
      """

      def __init__(self, type,
                   ed, global_ratio=0.25, local_ratio=0.25,
                   kernels=5, drop_path=0., has_skip=True, ssm_ratio=1, forward_type="v052d"):
          super().__init__()
          if type != 's':
              # Without this check, the module would be built without `mixer` and fail only
              # at the first forward call.
              raise ValueError(f"Unsupported MobileMambaBlock type {type!r}; only 's' is implemented")
          self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
          self.ffn0 = Residual(FFN(ed, int(ed * 2)))
          self.mixer = Residual(MobileMambaBlockWindow(ed, global_ratio=global_ratio, local_ratio=local_ratio,
                                                       kernels=kernels, ssm_ratio=ssm_ratio,
                                                       forward_type=forward_type))
          self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., ))
          self.ffn1 = Residual(FFN(ed, int(ed * 2)))
          self.has_skip = has_skip
          self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

      def forward(self, x):
          shortcut = x
          x = self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
          x = (shortcut + self.drop_path(x)) if self.has_skip else x
          return x
  class MobileMamba(torch.nn.Module):
      """MobileMamba classifier: a conv stem, three block stages, and a pooled BN-Linear head.

      The stem is four stride-2 Conv-BN layers with ReLU in between. Stage ``i`` holds
      ``depth[i]`` ``MobileMambaBlock``s. When ``down_ops[i][0] == 'subsample'``, the next
      stage starts with a dw-conv/FFN pair, a ``PatchMerging`` and another dw-conv/FFN pair.
      Features are global-average-pooled before the head. With ``distillation``, training
      returns ``(head, head_dist)`` logits and eval returns their average.

      Args:
          img_size: accepted for config compatibility; not used by the constructor.
          in_chans: input image channels.
          num_classes: classifier outputs (0 replaces the head with Identity).
          stages, embed_dim, global_ratio, local_ratio, depth, kernels, down_ops: per-stage
              settings (at most 3 stages).
          distillation: add a second (distillation) head.
          drop_path: max stochastic-depth rate, linearly spread over all blocks.
          ssm_ratio, forward_type: forwarded to every block.

      Raises:
          ValueError: if more than 3 stages are configured, or the last stage requests
              'subsample'.
      """

      def __init__(self, img_size=224,
                   in_chans=3,
                   num_classes=1000,
                   stages=['s', 's', 's'],
                   embed_dim=[192, 384, 448],
                   global_ratio=[0.8, 0.7, 0.6],
                   local_ratio=[0.2, 0.2, 0.3],
                   depth=[1, 2, 2],
                   kernels=[7, 5, 3],
                   down_ops=[['subsample', 2], ['subsample', 2], ['']],
                   distillation=False, drop_path=0., ssm_ratio=1, forward_type="v052d"):
          # NOTE: the list defaults are only read, never mutated, so sharing them across calls is safe.
          super().__init__()
          # Patch embedding
          self.patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1),
                                                 torch.nn.ReLU(),
                                                 Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, ),
                                                 torch.nn.ReLU(),
                                                 Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, ),
                                                 torch.nn.ReLU(),
                                                 Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1, ))
          # One list per stage. Plain locals replace the former eval('self.blocks' + str(...)).
          stage_blocks = [[], [], []]
          dprs = [x.item() for x in torch.linspace(0, drop_path, sum(depth))]
          # Build MobileMamba blocks
          for i, (stg, ed, dpth, gr, lr, do) in enumerate(
                  zip(stages, embed_dim, depth, global_ratio, local_ratio, down_ops)):
              if i >= len(stage_blocks):
                  raise ValueError("MobileMamba supports at most 3 stages")
              dpr = dprs[sum(depth[:i]):sum(depth[:i + 1])]
              for d in range(dpth):
                  stage_blocks[i].append(
                      MobileMambaBlock(stg, ed, gr, lr, kernels[i], dpr[d], ssm_ratio=ssm_ratio,
                                       forward_type=forward_type))
              if do and do[0] == 'subsample':
                  if i + 1 >= len(stage_blocks):
                      raise ValueError("'subsample' is not supported after the last stage")
                  # The downsample block is prepended to the NEXT stage.
                  next_stage = stage_blocks[i + 1]
                  next_stage.append(torch.nn.Sequential(
                      Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
                      Residual(FFN(embed_dim[i], int(embed_dim[i] * 2))), ))
                  next_stage.append(PatchMerging(*embed_dim[i:i + 2]))
                  next_stage.append(torch.nn.Sequential(
                      Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1], )),
                      Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2))), ))
          self.blocks1 = torch.nn.Sequential(*stage_blocks[0])
          self.blocks2 = torch.nn.Sequential(*stage_blocks[1])
          self.blocks3 = torch.nn.Sequential(*stage_blocks[2])
          # Classification head
          self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
          self.distillation = distillation
          if distillation:
              self.head_dist = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

      @torch.jit.ignore
      def no_weight_decay(self):
          return {x for x in self.state_dict().keys() if 'attention_biases' in x}

      def forward(self, x):
          x = self.patch_embed(x)
          x = self.blocks1(x)
          x = self.blocks2(x)
          x = self.blocks3(x)
          x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
          if self.distillation:
              x = self.head(x), self.head_dist(x)
              if not self.training:
                  x = (x[0] + x[1]) / 2
          else:
              x = self.head(x)
          return x
  def replace_batchnorm(net):
      """Prepare ``net`` for inference by fusing BatchNorms in place, recursively.

      - A child that exposes ``fuse()`` is replaced by its result, and that result is
        visited as well.
      - A remaining ``BatchNorm2d`` child becomes ``torch.nn.Identity``.
      - Any other child is searched recursively.
      """
      for name, module in net.named_children():
          if not hasattr(module, 'fuse'):
              if isinstance(module, torch.nn.BatchNorm2d):
                  setattr(net, name, torch.nn.Identity())
              else:
                  replace_batchnorm(module)
              continue
          replacement = module.fuse()
          setattr(net, name, replacement)
          replace_batchnorm(replacement)
  def _mobilemamba_cfg(img_size, embed_dim, depth, drop_path):
      """Build one MobileMamba variant config.

      Ratios, kernels and ssm_ratio are identical across all variants below. Every call
      returns a new dict with new lists, so variants never alias each other's values.
      """
      return {
          'img_size': img_size,
          'embed_dim': embed_dim,
          'depth': depth,
          'global_ratio': [0.8, 0.7, 0.6],
          'local_ratio': [0.2, 0.2, 0.3],
          'kernels': [7, 5, 3],
          'drop_path': drop_path,
          'ssm_ratio': 2,
      }


  CFG_MobileMamba_T2 = _mobilemamba_cfg(192, [144, 272, 368], [1, 2, 2], 0)
  CFG_MobileMamba_T4 = _mobilemamba_cfg(192, [176, 368, 448], [1, 2, 2], 0)
  CFG_MobileMamba_S6 = _mobilemamba_cfg(224, [192, 384, 448], [1, 2, 2], 0)
  CFG_MobileMamba_B1 = _mobilemamba_cfg(256, [200, 376, 448], [2, 3, 2], 0.03)
  CFG_MobileMamba_B2 = _mobilemamba_cfg(384, [200, 376, 448], [2, 3, 2], 0.03)
  CFG_MobileMamba_B4 = _mobilemamba_cfg(512, [200, 376, 448], [2, 3, 2], 0.03)
  def _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg):
      """Shared body of the registered ``MobileMamba_*`` factories.

      Builds ``MobileMamba(num_classes=..., distillation=..., **model_cfg)`` and, if ``fuse``
      is set, folds its BatchNorms with ``replace_batchnorm``. This function loads no
      weights. A truthy ``pretrained`` triggers a warning instead of being silently ignored.
      """
      if pretrained:
          import warnings
          warnings.warn("pretrained=True is ignored by the MobileMamba factories: no weights are "
                        "loaded here; load a checkpoint separately.", stacklevel=3)
      model = MobileMamba(num_classes=num_classes, distillation=distillation, **model_cfg)
      if fuse:
          replace_batchnorm(model)
      return model


  @MODEL.register_module
  def MobileMamba_T2(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None,
                     model_cfg=CFG_MobileMamba_T2):
      """Build MobileMamba-T2. ``pretrained_cfg`` is accepted for API compatibility and unused."""
      return _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg)


  @MODEL.register_module
  def MobileMamba_T4(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None,
                     model_cfg=CFG_MobileMamba_T4):
      """Build MobileMamba-T4. ``pretrained_cfg`` is accepted for API compatibility and unused."""
      return _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg)


  @MODEL.register_module
  def MobileMamba_S6(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None,
                     model_cfg=CFG_MobileMamba_S6):
      """Build MobileMamba-S6. ``pretrained_cfg`` is accepted for API compatibility and unused."""
      return _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg)


  @MODEL.register_module
  def MobileMamba_B1(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None,
                     model_cfg=CFG_MobileMamba_B1):
      """Build MobileMamba-B1. ``pretrained_cfg`` is accepted for API compatibility and unused."""
      return _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg)


  @MODEL.register_module
  def MobileMamba_B2(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None,
                     model_cfg=CFG_MobileMamba_B2):
      """Build MobileMamba-B2. ``pretrained_cfg`` is accepted for API compatibility and unused."""
      return _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg)


  @MODEL.register_module
  def MobileMamba_B4(num_classes=1000, pretrained=False, distillation=False, fuse=False, pretrained_cfg=None,
                     model_cfg=CFG_MobileMamba_B4):
      """Build MobileMamba-B4. ``pretrained_cfg`` is accepted for API compatibility and unused."""
      return _build_mobilemamba(num_classes, distillation, fuse, pretrained, model_cfg)
  if __name__ == "__main__":
      from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count
      from util.util import FLOPs, Throughput, get_val_dataloader
      import time
      import argparse

      def get_timepc():
          """Return ``time.perf_counter()``, synchronising CUDA first so queued kernels are included."""
          if torch.cuda.is_available():
              torch.cuda.synchronize()
          return time.perf_counter()

      model_dict = {
          "MobileMamba_T2": MobileMamba_T2,
          "MobileMamba_T4": MobileMamba_T4,
          "MobileMamba_S6": MobileMamba_S6,
          "MobileMamba_B1": MobileMamba_B1,
          "MobileMamba_B2": MobileMamba_B2,
          "MobileMamba_B4": MobileMamba_B4,
      }
      parser = argparse.ArgumentParser()
      parser.add_argument('-b', '--batchsize', type=int, default=256)
      parser.add_argument('-i', '--imagesize', type=int, default=224)
      # `choices` rejects unknown names up front instead of failing later with a KeyError.
      parser.add_argument('-m', '--modelname', default="MobileMamba_S6", choices=sorted(model_dict))
      cfg = parser.parse_args()
      bs = cfg.batchsize
      img_size = cfg.imagesize
      model_name = cfg.modelname
      print('batch_size is:', bs, 'img_size is:', img_size, 'model_name is:', model_name)
      # Use GPU 0 when CUDA exists; otherwise take the CPU path (-1) instead of crashing in
      # torch.cuda.set_device. NOTE(review): SS2D's kernels may still require CUDA; confirm.
      gpu_id = 0 if torch.cuda.is_available() else -1
      speed = True
      latency = True
      with torch.no_grad():
          x = torch.randn(bs, 3, img_size, img_size)
          net = model_dict[model_name]()
          replace_batchnorm(net)
          net.eval()
          pre_cnt, cnt = 2, 5
          if gpu_id > -1:
              torch.cuda.set_device(gpu_id)
              x = x.cuda()
              net.cuda()
              pre_cnt, cnt = 50, 20
          # Count FLOPs on one image placed on the same device as the model.
          FLOPs.fvcore_flop_count(net, torch.randn(1, 3, img_size, img_size, device=x.device), show_arch=False)
          # Warm-up iterations, then timed iterations.
          for _ in range(pre_cnt):
              net(x)
          t_s = get_timepc()
          for _ in range(cnt):
              net(x)
          t_e = get_timepc()
          speed = f'{bs * cnt / (t_e - t_s):>7.3f}'
          print(f'[Batchsize: {bs}]\t [GPU-Speed: {speed}]\t')