utils.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757
  1. import os
  2. import logging
  3. import sys
  4. import time
  5. import math
  6. from functools import partial
  7. from typing import Callable
  8. import numpy as np
  9. import torch
  10. import torch.nn as nn
  11. from timm.utils import AverageMeter
  12. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from torchvision import datasets, transforms
  14. from torch.utils.data import DataLoader, RandomSampler
  15. from collections import OrderedDict
  16. import cv2
  17. import PIL
  18. import tqdm
  19. from PIL import Image
  20. import os
  21. import sys
  22. import torch
  23. import torch.nn as nn
  24. from torch import Tensor
  25. from torch.nn.modules import Module
  26. from functools import partial
  27. from typing import Callable, Tuple, Union, Tuple, Union, Any
  28. from collections import defaultdict
  29. HOME = os.environ["HOME"].rstrip("/")
  30. def import_abspy(name="models", path="classification/"):
  31. import sys
  32. import importlib
  33. path = os.path.abspath(path)
  34. assert os.path.isdir(path)
  35. sys.path.insert(0, path)
  36. module = importlib.import_module(name)
  37. sys.path.pop(0)
  38. return module
  39. def get_dataset(root="./val", img_size=224, ret="", crop=True, single_image=False):
  40. from torch.utils.data import SequentialSampler, DistributedSampler, DataLoader
  41. size = int((256 / 224) * img_size) if crop else int(img_size)
  42. transform = transforms.Compose([
  43. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  44. transforms.CenterCrop((img_size, img_size)),
  45. transforms.ToTensor(),
  46. transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25)),
  47. ])
  48. if single_image:
  49. class ds(datasets.ImageFolder):
  50. def __init__(self, img, transform):
  51. self.transform = transform
  52. self.target_transform = None
  53. self.loader = datasets.folder.default_loader
  54. self.samples = [(img, 0)]
  55. self.targets = [0]
  56. self.classes = ["none"]
  57. self.class_to_idx = {"none": 0}
  58. dataset = ds(root, transform=transform)
  59. else:
  60. dataset = datasets.ImageFolder(root, transform=transform)
  61. if ret in dataset.classes:
  62. print(f"found target {ret}", flush=True)
  63. target = dataset.class_to_idx[ret]
  64. dataset.samples = [s for s in dataset.samples if s[1] == target]
  65. dataset.targets = [s for s in dataset.targets if s == target]
  66. dataset.classes = [ret]
  67. dataset.class_to_idx = {ret: target}
  68. return dataset
  69. def show_mask_on_image(img: torch.Tensor, mask: torch.Tensor, mask_norm=True):
  70. H, W, C = img.shape
  71. mH, mW = mask.shape
  72. mask = torch.nn.functional.interpolate(mask[None, None], (H, W), mode="bilinear")[0, 0]
  73. if mask_norm:
  74. mask = (mask - mask.min()) / (mask.max() - mask.min())
  75. img = img.clamp(min=0, max=1).cpu().numpy()
  76. mask = mask.clamp(min=0, max=1).cpu().numpy()
  77. heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
  78. # heatmap = np.float32(heatmap) / 255
  79. # cam = heatmap + np.float32(img)
  80. # cam = cam / np.max(cam)
  81. return heatmap
  82. return np.uint8(255 * cam)
  83. def get_val_dataloader(batch_size=64, root="./val", img_size=224, sequential=True):
  84. import torch.utils.data
  85. size = int((256 / 224) * img_size)
  86. transform = transforms.Compose([
  87. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  88. transforms.CenterCrop((img_size, img_size)),
  89. transforms.ToTensor(),
  90. transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
  91. ])
  92. dataset = datasets.ImageFolder(root, transform=transform)
  93. if sequential:
  94. sampler = torch.utils.data.SequentialSampler(dataset)
  95. else:
  96. sampler = torch.utils.data.DistributedSampler(dataset)
  97. data_loader = torch.utils.data.DataLoader(
  98. dataset, sampler=sampler,
  99. batch_size=batch_size,
  100. shuffle=False,
  101. num_workers=0,
  102. pin_memory=True,
  103. drop_last=False
  104. )
  105. return data_loader
  106. class visualize:
  107. @staticmethod
  108. def get_colormap(name):
  109. import matplotlib as mpl
  110. """Handle changes to matplotlib colormap interface in 3.6."""
  111. try:
  112. return mpl.colormaps[name]
  113. except AttributeError:
  114. return mpl.cm.get_cmap(name)
  115. @staticmethod
  116. def draw_image_grid(image: Image, grid=[(0, 0, 1, 1)], **kwargs):
  117. # grid[0]: (x,y,w,h)
  118. default = dict(fill=None, outline='red', width=3)
  119. default.update(kwargs)
  120. assert isinstance(grid, list) and isinstance(grid[0], tuple) and len(grid[0]) == 4
  121. from PIL import ImageDraw
  122. a = ImageDraw.ImageDraw(image)
  123. for g in grid:
  124. a.rectangle([(g[0], g[1]), (g[0] + g[2], g[1] + g[3])], **default)
  125. return image
  126. @staticmethod
  127. def visualize_attnmap(attnmap, savefig="", figsize=(18, 16), cmap=None, sticks=True, dpi=400, fontsize=35, colorbar=True, **kwargs):
  128. import matplotlib.pyplot as plt
  129. if isinstance(attnmap, torch.Tensor):
  130. attnmap = attnmap.detach().cpu().numpy()
  131. # if isinstance(imgori, torch.Tensor):
  132. # imgori = imgori.detach().cpu().numpy()
  133. plt.rcParams["font.size"] = fontsize
  134. plt.figure(figsize=figsize, dpi=dpi, **kwargs)
  135. ax = plt.gca()
  136. im = ax.imshow(attnmap, cmap=cmap)
  137. # ax.set_title(title)
  138. if not sticks:
  139. ax.set_axis_off()
  140. if colorbar:
  141. cbar = ax.figure.colorbar(im, ax=ax)
  142. if savefig == "":
  143. plt.show()
  144. else:
  145. plt.savefig(savefig)
  146. plt.close()
  147. @staticmethod
  148. def visualize_attnmaps(attnmaps, savefig="", figsize=(18, 16), rows=1, cmap=None, dpi=400, fontsize=35, linewidth=2, **kwargs):
  149. # attnmaps: [(map, title), (map, title),...]
  150. import math
  151. import matplotlib.pyplot as plt
  152. vmin = min([np.min((a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a)) for a, t in attnmaps])
  153. vmax = max([np.max((a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a)) for a, t in attnmaps])
  154. cols = math.ceil(len(attnmaps) / rows)
  155. plt.rcParams["font.size"] = fontsize
  156. figsize=(cols * figsize[0], rows * figsize[1])
  157. fig, axs = plt.subplots(rows, cols, squeeze=False, sharex="all", sharey="all", figsize=figsize, dpi=dpi)
  158. for i in range(rows):
  159. for j in range(cols):
  160. idx = i * cols + j
  161. if idx >= len(attnmaps):
  162. image = np.zeros_like(image)
  163. title = "pad"
  164. else:
  165. image, title = attnmaps[idx]
  166. if isinstance(image, torch.Tensor):
  167. image = image.detach().cpu().numpy()
  168. im = axs[i, j].imshow(image, vmin=vmin, vmax=vmax, cmap=cmap)
  169. axs[i, j].set_title(title)
  170. axs[i, j].set_yticks([])
  171. axs[i, j].set_xticks([])
  172. print(title, "max", np.max(image), "min", np.min(image), end=" | ")
  173. print("")
  174. axs[0, 0].figure.colorbar(im, ax=axs)
  175. if savefig == "":
  176. plt.show()
  177. else:
  178. plt.savefig(savefig)
  179. plt.close()
  180. print("")
  181. @staticmethod
  182. def seanborn_heatmap(
  183. data, *,
  184. vmin=None, vmax=None, cmap=None, center=None, robust=False,
  185. annot=None, fmt=".2g", annot_kws=None,
  186. linewidths=0, linecolor="white",
  187. cbar=True, cbar_kws=None, cbar_ax=None,
  188. square=False, xticklabels="auto", yticklabels="auto",
  189. mask=None, ax=None,
  190. **kwargs
  191. ):
  192. from matplotlib import pyplot as plt
  193. from seaborn.matrix import _HeatMapper
  194. # Initialize the plotter object
  195. plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt,
  196. annot_kws, cbar, cbar_kws, xticklabels,
  197. yticklabels, mask)
  198. # Add the pcolormesh kwargs here
  199. kwargs["linewidths"] = linewidths
  200. kwargs["edgecolor"] = linecolor
  201. # Draw the plot and return the Axes
  202. if ax is None:
  203. ax = plt.gca()
  204. if square:
  205. ax.set_aspect("equal")
  206. plotter.plot(ax, cbar_ax, kwargs)
  207. mesh = ax.pcolormesh(plotter.plot_data, cmap=plotter.cmap, **kwargs)
  208. return ax, mesh
  209. @classmethod
  210. def visualize_snsmap(cls, attnmap, savefig="", figsize=(18, 16), cmap=None, sticks=True, dpi=80, fontsize=35, linewidth=2, **kwargs):
  211. import matplotlib.pyplot as plt
  212. if isinstance(attnmap, torch.Tensor):
  213. attnmap = attnmap.detach().cpu().numpy()
  214. plt.rcParams["font.size"] = fontsize
  215. plt.figure(figsize=figsize, dpi=dpi, **kwargs)
  216. ax = plt.gca()
  217. _, mesh = cls.seanborn_heatmap(attnmap, xticklabels=sticks, yticklabels=sticks, cmap=cmap, linewidths=0,
  218. center=0, annot=False, ax=ax, cbar=False, annot_kws={"size": 24}, fmt='.2f')
  219. cb = ax.figure.colorbar(mesh, ax=ax)
  220. cb.outline.set_linewidth(0)
  221. if savefig == "":
  222. plt.show()
  223. else:
  224. plt.savefig(savefig)
  225. plt.close()
  226. @classmethod
  227. def visualize_snsmaps(cls, attnmaps, savefig="", figsize=(18, 16), rows=1, cmap=None, sticks=True, dpi=80, fontsize=35, linewidth=2, **kwargs):
  228. # attnmaps: [(map, title), (map, title),...]
  229. import math
  230. import matplotlib.pyplot as plt
  231. vmin = min([np.min((a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a)) for a, t in attnmaps])
  232. vmax = max([np.max((a.detach().cpu().numpy() if isinstance(a, torch.Tensor) else a)) for a, t in attnmaps])
  233. cols = math.ceil(len(attnmaps) / rows)
  234. plt.rcParams["font.size"] = fontsize
  235. figsize=(cols * figsize[0], rows * figsize[1])
  236. fig, axs = plt.subplots(rows, cols, squeeze=False, sharex="all", sharey="all", figsize=figsize, dpi=dpi)
  237. for i in range(rows):
  238. for j in range(cols):
  239. idx = i * cols + j
  240. if idx >= len(attnmaps):
  241. image = np.zeros_like(image)
  242. title = "pad"
  243. else:
  244. image, title = attnmaps[idx]
  245. if isinstance(image, torch.Tensor):
  246. image = image.detach().cpu().numpy()
  247. _, im = cls.seanborn_heatmap(image, xticklabels=sticks, yticklabels=sticks,
  248. vmin=vmin, vmax=vmax, cmap=cmap,
  249. center=0, annot=False, ax=axs[i, j],
  250. cbar=False, annot_kws={"size": 24}, fmt='.2f')
  251. axs[i, j].set_title(title)
  252. cb = axs[0, 0].figure.colorbar(im, ax=axs)
  253. cb.outline.set_linewidth(0)
  254. if savefig == "":
  255. plt.show()
  256. else:
  257. plt.savefig(savefig)
  258. plt.close()
# used for visualizing the effective receptive field
  260. class EffectiveReceiptiveField:
  261. @staticmethod
  262. def simpnorm(data):
  263. data = np.power(data, 0.2)
  264. data = data / np.max(data)
  265. return data
  266. @staticmethod
  267. def get_rectangle(data, thresh):
  268. h, w = data.shape
  269. all_sum = np.sum(data)
  270. for i in range(1, h // 2):
  271. selected_area = data[h // 2 - i:h // 2 + 1 + i, w // 2 - i:w // 2 + 1 + i]
  272. area_sum = np.sum(selected_area)
  273. if area_sum / all_sum > thresh:
  274. return i * 2 + 1, (i * 2 + 1) / h * (i * 2 + 1) / w
  275. return None, None
  276. @staticmethod
  277. def get_input_grad(model, samples, square=True):
  278. outputs = model(samples)
  279. out_size = outputs.size()
  280. if square:
  281. assert out_size[2] == out_size[3], out_size
  282. central_point = torch.nn.functional.relu(outputs[:, :, out_size[2] // 2, out_size[3] // 2]).sum()
  283. grad = torch.autograd.grad(central_point, samples)
  284. grad = grad[0]
  285. grad = torch.nn.functional.relu(grad)
  286. aggregated = grad.sum((0, 1))
  287. grad_map = aggregated.cpu().numpy()
  288. return grad_map
  289. @classmethod
  290. def get_input_grad_avg(cls, model: nn.Module, size=1024, data_path="ImageNet", num_images=50, norms=lambda x:x, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
  291. import tqdm
  292. from torchvision import datasets, transforms
  293. from torch.utils.data import DataLoader, RandomSampler
  294. transform = transforms.Compose([
  295. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  296. transforms.CenterCrop(size),
  297. transforms.ToTensor(),
  298. transforms.Normalize(mean, std)
  299. ])
  300. dataset = datasets.ImageFolder(os.path.join(data_path, 'val'), transform=transform)
  301. data_loader_val = DataLoader(dataset, sampler=RandomSampler(dataset), pin_memory=True)
  302. meter = AverageMeter()
  303. model.cuda().eval()
  304. for _, (samples, _) in tqdm.tqdm(enumerate(data_loader_val)):
  305. if meter.count == num_images:
  306. break
  307. samples = samples.cuda(non_blocking=True).requires_grad_()
  308. contribution_scores = cls.get_input_grad(model, samples)
  309. if np.isnan(np.sum(contribution_scores)):
  310. print("got nan | ", end="")
  311. continue
  312. else:
  313. meter.update(contribution_scores)
  314. return norms(meter.avg)
  315. # used for visualizing the attention of mamba
  316. class AttnMamba:
  317. @staticmethod
  318. def convert_state_dict_from_mmdet(state_dict):
  319. new_state_dict = OrderedDict()
  320. for k in state_dict:
  321. if k.startswith("backbone."):
  322. new_state_dict[k[len("backbone."):]] = state_dict[k]
  323. return new_state_dict
  324. @staticmethod
  325. def checkpostfix(tag, value):
  326. ret = value[-len(tag):] == tag
  327. if ret:
  328. value = value[:-len(tag)]
  329. return ret, value
  330. @staticmethod
  331. @torch.no_grad()
  332. def attnmap_mamba(regs, mode="CB", ret="all", absnorm=0, scale=1, verbose=False, device=None):
  333. printlog = print if verbose else lambda *args, **kwargs: None
  334. print(f"attn for mode={mode}, ret={ret}, absnorm={absnorm}, scale={scale}", flush=True)
  335. _norm = lambda x: x
  336. if absnorm == 1:
  337. _norm = lambda x: ((x - x.min()) / (x.max() - x.min()))
  338. elif absnorm == 2:
  339. _norm = lambda x: (x.abs() / x.abs().max())
  340. As, Bs, Cs, Ds = -torch.exp(regs["A_logs"].to(torch.float32)), regs["Bs"], regs["Cs"], regs["Ds"]
  341. us, dts, delta_bias = regs["us"], regs["dts"], regs["delta_bias"]
  342. ys, oy = regs["ys"], regs["y"]
  343. H, W = regs["H"], regs["W"]
  344. printlog(As.shape, Bs.shape, Cs.shape, Ds.shape, us.shape, dts.shape, delta_bias.shape)
  345. B, G, N, L = Bs.shape
  346. GD, N = As.shape
  347. D = GD // G
  348. H, W = int(math.sqrt(L)), int(math.sqrt(L))
  349. if device is not None:
  350. As, Bs, Cs, Ds, us, dts, delta_bias, ys, oy = As.to(device), Bs.to(device), Cs.to(device), Ds.to(device), us.to(device), dts.to(device), delta_bias.to(device), ys.to(device), oy.to(device)
  351. mask = torch.tril(dts.new_ones((L, L)))
  352. dts = torch.nn.functional.softplus(dts + delta_bias[:, None]).view(B, G, D, L)
  353. dw_logs = As.view(G, D, N)[None, :, :, :, None] * dts[:,:,:,None,:] # (B, G, D, N, L)
  354. ws = torch.cumsum(dw_logs, dim=-1).exp()
  355. if mode == "CB":
  356. Qs, Ks = Cs[:,:,None,:,:], Bs[:,:,None,:,:]
  357. elif mode == "CBdt":
  358. Qs, Ks = Cs[:,:,None,:,:], Bs[:,:,None,:,:] * dts.view(B, G, D, 1, L)
  359. elif mode == "CwBw":
  360. Qs, Ks = Cs[:,:,None,:,:] * ws, Bs[:,:,None,:,:] / ws.clamp(min=1e-20)
  361. elif mode == "CwBdtw":
  362. Qs, Ks = Cs[:,:,None,:,:] * ws, Bs[:,:,None,:,:] * dts.view(B, G, D, 1, L) / ws.clamp(min=1e-20)
  363. elif mode == "ww":
  364. Qs, Ks = ws, 1 / ws.clamp(min=1e-20)
  365. else:
  366. raise NotImplementedError
  367. printlog(ws.shape, Qs.shape, Ks.shape)
  368. printlog("Bs", Bs.max(), Bs.min(), Bs.abs().min())
  369. printlog("Cs", Cs.max(), Cs.min(), Cs.abs().min())
  370. printlog("ws", ws.max(), ws.min(), ws.abs().min())
  371. printlog("Qs", Qs.max(), Qs.min(), Qs.abs().min())
  372. printlog("Ks", Ks.max(), Ks.min(), Ks.abs().min())
  373. _Qs, _Ks = Qs.view(-1, N, L), Ks.view(-1, N, L)
  374. attns = (_Qs.transpose(1, 2) @ _Ks).view(B, G, -1, L, L)
  375. attns = attns.mean(dim=2) * mask
  376. attn0 = attns[:, 0, :].view(B, -1, L, L)
  377. attn1 = attns[:, 1, :].view(-1, H, W, H, W).permute(0, 2, 1, 4, 3).contiguous().view(B, -1, L, L)
  378. attn2 = attns[:, 2, :].view(-1, L, L).flip(dims=[-2]).flip(dims=[-1]).contiguous().view(B, -1, L, L)
  379. attn3 = attns[:, 3, :].view(-1, L, L).flip(dims=[-2]).flip(dims=[-1]).contiguous().view(B, -1, L, L)
  380. attn3 = attn3.view(-1, H, W, H, W).permute(0, 2, 1, 4, 3).contiguous().view(B, -1, L, L)
  381. # ao0, ao1, ao2, ao3: attntion in four directions without rearrange
  382. # a0, a1, a2, a3: attntion in four directions with rearrange
  383. # a0a2, a1a3, a0a1: combination of "a0, a1, a2, a3"
  384. # all: combination of all "a0, a1, a2, a3"
  385. if ret in ["ao0"]:
  386. attn = _norm(attns[:, 0, :]).view(B, -1, L, L).mean(dim=1)
  387. elif ret in ["ao1"]:
  388. attn = _norm(attns[:, 1, :]).view(B, -1, L, L).mean(dim=1)
  389. elif ret in ["ao2"]:
  390. attn = _norm(attns[:, 2, :]).view(B, -1, L, L).mean(dim=1)
  391. elif ret in ["ao3"]:
  392. attn = _norm(attns[:, 3, :]).view(B, -1, L, L).mean(dim=1)
  393. elif ret in ["a0"]:
  394. attn = _norm(attn0).mean(dim=1)
  395. elif ret in ["a1"]:
  396. attn = _norm(attn1).mean(dim=1)
  397. elif ret in ["a2"]:
  398. attn = _norm(attn2).mean(dim=1)
  399. elif ret in ["a3"]:
  400. attn = _norm(attn3).mean(dim=1)
  401. elif ret in ["all"]:
  402. attn = _norm((attn0 + attn1 + attn2 + attn3)).mean(dim=1)
  403. elif ret in ["nall"]:
  404. attn = (_norm(attn0) + _norm(attn1) + _norm(attn2) + _norm(attn3)).mean(dim=1) / 4.0
  405. else:
  406. raise NotImplementedError(f"{ret} is not allowed")
  407. attn = (scale * attn).clamp(max=attn.max())
  408. return attn[0], H, W
  409. @classmethod
  410. @torch.no_grad()
  411. def get_attnmap_mamba(cls, ss2ds, stage=-1, mode="", verbose=False, raw_attn=False, block_id=0, scale=1, device=None):
  412. mode1 = mode.split("_")[-1]
  413. mode = mode[:-(len(mode1) + 1)]
  414. absnorm = 0
  415. tag, mode = cls.checkpostfix("_absnorm", mode)
  416. absnorm = 2 if tag else absnorm
  417. tag, mode = cls.checkpostfix("_norm", mode)
  418. absnorm = 1 if tag else absnorm
  419. if raw_attn:
  420. ss2d = ss2ds if not isinstance(ss2ds, list) else ss2ds[stage][block_id]
  421. regs = getattr(ss2d, "__data__")
  422. attn, H, W = cls.attnmap_mamba(regs, mode=mode1, ret=mode, absnorm=absnorm, verbose=verbose, scale=scale)
  423. return attn
  424. allrolattn = None
  425. for k in range(len(ss2ds[stage])):
  426. regs = getattr(ss2ds[stage][k], "__data__")
  427. attn, H, W = cls.attnmap_mamba(regs, mode=mode1, ret=mode, absnorm=absnorm, verbose=verbose, scale=scale)
  428. L = H * W
  429. assert attn.shape == (L, L)
  430. assert attn.max() <= 1
  431. assert attn.min() >= 0
  432. rolattn = 0.5 * (attn.cpu() + torch.eye(L))
  433. rolattn = rolattn / rolattn.sum(-1)
  434. allrolattn = (rolattn @ allrolattn) if allrolattn is not None else rolattn
  435. return allrolattn
# used for testing throughput
  437. class Throughput:
  438. # default no amp in testing tp
  439. # copied from swin_transformer
  440. @staticmethod
  441. @torch.no_grad()
  442. def throughput(data_loader, model, logger=logging):
  443. model.eval()
  444. for idx, (images, _) in enumerate(data_loader):
  445. images = images.cuda(non_blocking=True)
  446. batch_size = images.shape[0]
  447. for i in range(50):
  448. model(images)
  449. torch.cuda.synchronize()
  450. logger.info(f"throughput averaged with 30 times")
  451. torch.cuda.reset_peak_memory_stats()
  452. tic1 = time.time()
  453. for i in range(30):
  454. model(images)
  455. torch.cuda.synchronize()
  456. tic2 = time.time()
  457. logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  458. logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
  459. return
  460. @staticmethod
  461. @torch.no_grad()
  462. def throughputamp(data_loader, model, logger=logging):
  463. model.eval()
  464. for idx, (images, _) in enumerate(data_loader):
  465. images = images.cuda(non_blocking=True)
  466. batch_size = images.shape[0]
  467. for i in range(50):
  468. with torch.cuda.amp.autocast():
  469. model(images)
  470. torch.cuda.synchronize()
  471. logger.info(f"throughput averaged with 30 times")
  472. torch.cuda.reset_peak_memory_stats()
  473. tic1 = time.time()
  474. for i in range(30):
  475. with torch.cuda.amp.autocast():
  476. model(images)
  477. torch.cuda.synchronize()
  478. tic2 = time.time()
  479. logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  480. logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
  481. return
  482. @staticmethod
  483. def testfwdbwd(data_loader, model, logger, amp=True):
  484. model.cuda().train()
  485. criterion = torch.nn.CrossEntropyLoss()
  486. for idx, (images, targets) in enumerate(data_loader):
  487. images = images.cuda(non_blocking=True)
  488. targets = targets.cuda(non_blocking=True)
  489. batch_size = images.shape[0]
  490. for i in range(50):
  491. with torch.cuda.amp.autocast(enabled=amp):
  492. out = model(images)
  493. loss = criterion(out, targets)
  494. loss.backward()
  495. torch.cuda.synchronize()
  496. logger.info(f"testfwdbwd averaged with 30 times")
  497. torch.cuda.reset_peak_memory_stats()
  498. tic1 = time.time()
  499. for i in range(30):
  500. with torch.cuda.amp.autocast(enabled=amp):
  501. out = model(images)
  502. loss = criterion(out, targets)
  503. loss.backward()
  504. torch.cuda.synchronize()
  505. tic2 = time.time()
  506. logger.info(f"batch_size {batch_size} testfwdbwd {30 * batch_size / (tic2 - tic1)}")
  507. logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
  508. return
  509. @classmethod
  510. def testall(cls, model, dataloader=None, data_path="", img_size=224, _batch_size=128, with_flops=True, inference_only=False):
  511. from fvcore.nn import parameter_count
  512. torch.cuda.empty_cache()
  513. model.cuda().eval()
  514. if with_flops:
  515. try:
  516. FLOPs.fvcore_flop_count(model, input_shape=(3, img_size, img_size), show_arch=False)
  517. except Exception as e:
  518. print("ERROR:", e, flush=True)
  519. print(parameter_count(model)[""], sum(p.numel() for p in model.parameters() if p.requires_grad), flush=True)
  520. if dataloader is None:
  521. dataloader = get_val_dataloader(
  522. batch_size=_batch_size,
  523. root=os.path.join(os.path.abspath(data_path), "val"),
  524. img_size=img_size,
  525. )
  526. cls.throughput(data_loader=dataloader, model=model, logger=logging)
  527. if inference_only:
  528. return
  529. PASS = False
  530. batch_size = _batch_size
  531. while (not PASS) and (batch_size > 0):
  532. try:
  533. _dataloader = get_val_dataloader(
  534. batch_size=batch_size,
  535. root=os.path.join(os.path.abspath(data_path), "val"),
  536. img_size=img_size,
  537. )
  538. cls.testfwdbwd(data_loader=_dataloader, model=model, logger=logging)
  539. cls.testfwdbwd(data_loader=_dataloader, model=model, logger=logging, amp=False)
  540. PASS = True
  541. except:
  542. batch_size = batch_size // 2
  543. print(f"batch_size {batch_size}", flush=True)
# used for extracting features
  545. class ExtractFeatures:
  546. @staticmethod
  547. def get_list_dataset(*args, **kwargs):
  548. class DatasetList:
  549. def __init__(self, batch_size=16, root="train/", img_size=224, weak_aug=False):
  550. self.batch_size = int(batch_size)
  551. transform, transform_waug = self.get_transform(img_size)
  552. self.transform = transform_waug if weak_aug else transform
  553. self.dataset = datasets.ImageFolder(root, transform=self.transform)
  554. self.num_data = int(len(self.dataset))
  555. self.num_batches = math.ceil(self.num_data / self.batch_size)
  556. print(f"weak aug: {weak_aug} =========================", flush=True)
  557. @staticmethod
  558. def get_transform(img_size=224):
  559. size = int((256 / 224) * img_size)
  560. transform = transforms.Compose([
  561. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  562. transforms.CenterCrop((img_size, img_size)),
  563. transforms.ToTensor(),
  564. transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
  565. ])
  566. transform_waug = transforms.Compose([
  567. transforms.RandomResizedCrop(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
  568. transforms.RandomHorizontalFlip(),
  569. transforms.ToTensor(),
  570. transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
  571. ])
  572. return transform, transform_waug
  573. def __len__(self):
  574. return self.num_batches
  575. def __getitem__(self, idx):
  576. start = idx * self.batch_size
  577. end = min(start + self.batch_size, self.num_data)
  578. data = [self.dataset[i] for i in range(start, end)]
  579. images = torch.stack([img for img, tgt in data])
  580. targets = torch.stack([torch.tensor(tgt) for img, tgt in data])
  581. if len(images) < self.batch_size:
  582. _images = torch.zeros((self.batch_size, *data[0][0].shape))
  583. _targets = -1 * torch.ones((self.batch_size,))
  584. _images[:len(images)] = images
  585. _targets[:len(images)] = targets
  586. return _images, _targets
  587. return images, targets
  588. return DatasetList(*args, **kwargs)
  @classmethod
  def extract_feature(
      cls,
      backbones=dict(), # dict(name=model)
      batch_size=16,
      img_size=1024,
      data_path="ImageNet_ILSVRC2012",
      amp_disable=False,
      dims=dict(), # dict(name=dim)
      outdir=os.path.join(HOME, "ckpts/feats/unmerge/"),
      ranges=[0, 1000],
      train=True,
      aug=False,
  ):
      """Run every backbone over a range of batches and dump the outputs.

      Args:
          backbones: mapping ``name -> model``; each model is moved to CUDA,
              put in eval mode and wrapped in ``DistributedDataParallel``.
          batch_size: batch size handed to ``cls.get_list_dataset``.
          img_size: image size handed to ``cls.get_list_dataset``.
          data_path: dataset root; ``./train`` or ``./val`` is appended.
          amp_disable: when True, autocast is disabled for the forward pass.
          dims: mapping ``name -> feature dim`` used to pre-allocate the
              output buffer of each backbone (needs an entry per backbone).
          outdir: base output directory; a run-specific sub-directory is
              created inside it.
          ranges: ``[start, end)`` batch indices; ``end <= 0`` means "up to
              the last batch", and ``end`` is clipped to the dataset length.
          train: select the ``train`` (True) or ``val`` (False) split.
          aug: forwarded as ``weak_aug`` to ``cls.get_list_dataset``.

      Writes one ``<name>_bs..._e<end>.pth`` file per backbone plus one
      ``targets_...pth`` file. The default arguments are never mutated
      (``ranges`` is copied, ``backbones`` is rebound).
      """
      root = os.path.join(data_path, "./train") if train else os.path.join(data_path, "./val")
      datasetlist = cls.get_list_dataset(batch_size, root=root, img_size=img_size, weak_aug=aug)

      ranges = list(ranges)
      # Validate the length before indexing so a malformed range fails the
      # assert instead of raising IndexError.
      assert len(ranges) == 2, f"{ranges}"
      if ranges[1] <= 0:
          ranges[1] = len(datasetlist)
      ranges[1] = min(ranges[1], len(datasetlist))
      assert ranges[1] > ranges[0], f"{ranges}"
      outbatches = ranges[1] - ranges[0]

      outdir = os.path.join(outdir, f"sz{img_size}_bs{batch_size}_range{ranges[0]}_{ranges[1]}" + ("" if train else "_val"))
      os.makedirs(outdir, exist_ok=True)

      backbones = {
          name: torch.nn.parallel.DistributedDataParallel(model.cuda().eval())
          for name, model in backbones.items()
      }
      feats = {
          name: torch.zeros((outbatches, batch_size, dim))
          for name, dim in dims.items()
      }
      all_targets = torch.zeros((outbatches, batch_size))

      print("=" * 50, flush=True)
      print(f"using backbones {backbones.keys()}", flush=True)
      print(f"batch_size {batch_size} img_size {img_size} ranges {ranges} max_range {0} {len(datasetlist)}", flush=True)

      with torch.no_grad():
          for i, idx in enumerate(tqdm.tqdm(range(ranges[0], ranges[1]))):
              images, targets = datasetlist[idx]
              images = images.cuda(non_blocking=True)
              all_targets[i] = targets.detach().cpu()
              for name, model in backbones.items():
                  with torch.cuda.amp.autocast(enabled=(not amp_disable)):
                      feats[name][i] = model(images).detach().cpu()

      # Context managers make sure each output file is flushed and closed.
      suffix = f"bs{batch_size}_sz{img_size}_obs{outbatches}_s{ranges[0]}_e{ranges[1]}.pth"
      for name in backbones:
          with open(os.path.join(outdir, f"{name}_{suffix}"), "wb") as fout:
              torch.save(feats[name], fout)
      with open(os.path.join(outdir, f"targets_{suffix}"), "wb") as fout:
          torch.save(all_targets, fout)
  638. @staticmethod
  639. def merge_feats(features=[], targets=[], length=1281167, save="/tmp/1.pth"):
  640. feats = [torch.load(open(f, "rb")) for f in features]
  641. tgts = [torch.load(open(f, "rb")) for f in targets]
  642. for i, (f, t) in enumerate(zip(feats, tgts)):
  643. assert f.shape[0:2] == t.shape[0:2], breakpoint()
  644. assert sum([t.shape[0] for t in tgts]) * tgts[0].shape[1] >= length
  645. print(features, targets, flush=True)
  646. feats = torch.cat(feats, dim=0).view(-1, feats[0].shape[-1])
  647. tgts = torch.cat(tgts, dim=0).view(-1)
  648. if not (len(feats) == length):
  649. assert (feats[length:] == feats[length]).all() # input 0, models output same
  650. assert (feats[length] != feats[length - 1]).any()
  651. assert (tgts[length:] == -1).all()
  652. assert (tgts[:length] != -1).all()
  653. feats = feats[:length]
  654. tgts = tgts[:length]
  655. os.makedirs(os.path.dirname(save), exist_ok=True)
  656. assert not os.path.exists(save), f"file {save} exist"
  657. torch.save(dict(features=feats, targets=tgts), open(save, "wb"))
  658. # used for build models
  659. class BuildModels:
  @staticmethod
  def build_vheat(with_ckpt=False, remove_head=False, only_backbone=False, scale="small", size=224):
      """Build a vHeat model in inference mode.

      Only the plain model is supported: ``with_ckpt``, ``remove_head`` and
      ``only_backbone`` must all be falsy. ``scale`` is one of ``"tiny"``,
      ``"small"`` or ``"base"``; ``size`` is passed on as ``img_size``.
      ``infer_init()`` is called on the model before it is returned.
      """
      assert not with_ckpt
      assert not remove_head
      assert not only_backbone
      print("vheat ================================", flush=True)

      vheat_module = import_abspy("vheat", f"{HOME}/packs/VHeat/classification/models")
      vheat_cls = vheat_module.HeatM_V2_Stem_Noangle_Freqembed_Oldhead_Fast2_Torelease

      # scale -> (depths, dims)
      variants = {
          "tiny": ([2, 2, 6, 2], 96),
          "small": ([2, 2, 18, 2], 96),
          "base": ([2, 2, 18, 2], 128),
      }
      depths, dims = variants[scale]
      net = vheat_cls(depths=depths, dims=dims, img_size=size, infer_mode=True)
      net.infer_init()
      return net
  @staticmethod
  def build_visionmamba(with_ckpt=False, remove_head=False, only_backbone=False, scale="small", size=224):
      """Build the Vision Mamba (Vim) small model.

      ``with_ckpt``, ``remove_head`` and ``only_backbone`` must all be
      falsy. ``scale`` and ``size`` are accepted for signature parity with
      the other builders but are not used: the small/224 variant is always
      built.
      """
      assert not with_ckpt
      assert not remove_head
      assert not only_backbone
      print("vim ================================", flush=True)
      specpath = f"{HOME}/packs/Vim/mamba-1p1p1"
      # The bundled mamba_ssm must shadow any installed copy while importing.
      sys.path.insert(0, specpath)
      try:
          import mamba_ssm
          _model = import_abspy("models_mamba", f"{HOME}/packs/Vim/vim")
          model = _model.vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2()
      finally:
          # Always undo the path change, even when an import fails, and remove
          # exactly the entry added above rather than whatever sits at index 0.
          if specpath in sys.path:
              sys.path.remove(specpath)
      return model
  @staticmethod
  def build_s4nd(with_ckpt=False, remove_head=False, only_backbone=False, scale="ctiny", size=224):
      """Build an S4ND model (``"ctiny"`` ConvNeXt-tiny or ``"vitb"`` ViT-base).

      ``with_ckpt`` and ``remove_head`` must be falsy. When
      ``only_backbone`` is set, ``forward`` is replaced by the model's
      ``forward_features``. Requires ``timm == 0.5.4``. ``size`` is not
      used.
      """
      assert not with_ckpt
      assert not remove_head
      assert scale in ["vitb", "ctiny"]
      print("convnext-s4nd ================================", flush=True)
      specpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "./convnexts4nd")
      sys.path.insert(0, specpath)
      try:
          import timm; assert timm.__version__ == "0.5.4"
          import structured_kernels
          vit_module = import_abspy("vit_all", f"{os.path.dirname(__file__)}/convnexts4nd")
          vitb = vit_module.vit_base_s4nd
          convnext_module = import_abspy("convnext_timm", f"{os.path.dirname(__file__)}/convnexts4nd")
          ctiny = convnext_module.convnext_tiny_s4nd
          model = dict(ctiny=ctiny, vitb=vitb)[scale]()
      finally:
          # Restore sys.path even if an import or the constructor raised.
          if specpath in sys.path:
              sys.path.remove(specpath)
      if only_backbone:
          model.forward = model.forward_features
      return model
  @staticmethod
  def build_vmamba(with_ckpt=False, remove_head=False, only_backbone=False, scale="tv0", size=224, cfg=None, ckpt=None, key="model"):
      """Build a VMamba (VSSM) model.

      Args:
          with_ckpt: load weights from the checkpoint file (entry ``key``).
          remove_head: replace ``model.classifier.head`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) replace
              ``forward`` so that it returns the output of the last layer in
              channel-first layout.
          scale: a preset name (``tv0 tv1 tv2 sv0 sv2 bv0 bv2``) or
              ``"flex"`` to build ``VSSM(**cfg)`` with the caller's ``ckpt``.
          size: not used.
          cfg: keyword arguments for ``VSSM``; only read when ``scale == "flex"``.
          ckpt: checkpoint path; only honoured when ``scale == "flex"``
              (presets carry their own path).
          key: entry of the loaded checkpoint holding the state dict.
      """
      print("vssm ================================", flush=True)
      _model = import_abspy("vmamba", f"{os.path.dirname(__file__)}/../classification/models")

      if scale == "flex":
          model = _model.VSSM(**cfg)
      else:
          # Settings shared by the "v1"/"v2" presets and by the "v0" presets.
          v2_common = dict(ssm_d_state=1, ssm_dt_rank="auto", ssm_conv=3, ssm_conv_bias=False, forward_type="v05_noz", mlp_ratio=4.0, downsample_version="v3", patchembed_version="v2", norm_layer="ln2d")
          v0_common = dict(ssm_d_state=16, ssm_dt_rank="auto", ssm_ratio=2.0, forward_type="v05", mlp_ratio=0.0, downsample_version="v1", patchembed_version="v1", norm_layer="ln2d")
          # name -> (VSSM kwargs, checkpoint path)
          presets = dict(
              tv2=(
                  dict(dims=96, depths=[2,2,8,2], ssm_ratio=1.0, **v2_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_tiny_0230s/vssm1_tiny_0230s_ckpt_epoch_264.pth",
              ),
              sv2=(
                  dict(dims=96, depths=[2,2,15,2], ssm_ratio=2.0, **v2_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_small_0229/vssm1_small_0229_ckpt_epoch_222.pth",
              ),
              bv2=(
                  dict(dims=128, depths=[2,2,15,2], ssm_ratio=2.0, **v2_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_base_0229/vssm1_base_0229_ckpt_epoch_237.pth",
              ),
              tv1=(
                  dict(dims=96, depths=[2,2,5,2], ssm_ratio=2.0, **v2_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_tiny_0230/vssm1_tiny_0230_ckpt_epoch_262.pth",
              ),
              tv0=(
                  dict(dims=96, depths=[2,2,9,2], **v0_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm/classification/vssmtiny/vssmtiny_dp01_ckpt_epoch_292.pth",
              ),
              sv0=(
                  dict(dims=96, depths=[2,2,27,2], **v0_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm/classification/vssmsmall/vssmsmall_dp03_ckpt_epoch_238.pth",
              ),
              bv0=(
                  dict(dims=128, depths=[2,2,27,2], **v0_common),
                  f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm/classification/vssmbase/vssmbase_dp06_ckpt_epoch_241.pth",
              ),
          )
          # A preset overrides any ``ckpt`` passed by the caller.
          vssm_kwargs, ckpt = presets[scale]
          model = _model.VSSM(**vssm_kwargs)

      if with_ckpt:
          with open(ckpt, "rb") as fin:
              state = torch.load(fin, map_location=torch.device("cpu"))
          model.load_state_dict(state[key])

      if remove_head:
          print(model.classifier.head, flush=True)
          model.classifier.head = nn.Identity() # 768->1000
      elif only_backbone:
          def _forward(self, x: torch.Tensor):
              x = self.patch_embed(x)
              if self.pos_embed is not None:
                  pos_embed = self.pos_embed.permute(0, 2, 3, 1) if not self.channel_first else self.pos_embed
                  x = x + pos_embed
              for layer in self.layers:
                  x = layer(x)
              # Always hand back a channel-first feature map.
              if not self.channel_first:
                  x = x.permute(0, 3, 1, 2).contiguous()
              return x
          model.forward = partial(_forward, model)

      return model
  @staticmethod
  def build_swin(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build a Swin Transformer from the official repository checkout.

      Args:
          with_ckpt: load the bundled tiny/224 checkpoint (requires
              ``scale == "tiny"`` and ``size == 224``).
          remove_head: replace ``model.head`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) replace
              ``forward`` so it returns the last stage output reshaped to a
              square ``(B, C, H, W)`` map.
          scale: ``"tiny"``, ``"small"`` or ``"base"``.
          size: input image size; the window size is ``size // 32``.
      """
      print("swin ================================", flush=True)
      specpath = f"{HOME}/packs/Swin-Transformer"
      sys.path.insert(0, specpath)
      try:
          import swin_window_process
          _model = import_abspy("swin_transformer", f"{HOME}/packs/Swin-Transformer/models")
      finally:
          # Undo the path change even when the import fails.
          if specpath in sys.path:
              sys.path.remove(specpath)

      # configs/swin/swin_tiny_patch4_window7_224.yaml
      tiny = partial(_model.SwinTransformer, embed_dim=96, depths=[2,2,6,2], num_heads=[ 3, 6, 12, 24 ], img_size=size, window_size=(size//32), fused_window_process=True)
      # configs/swin/swin_small_patch4_window7_224.yaml
      small = partial(_model.SwinTransformer, embed_dim=96, depths=[2,2,18,2], num_heads=[ 3, 6, 12, 24 ], img_size=size, window_size=(size//32), fused_window_process=True)
      # configs/swin/swin_base_patch4_window7_224.yaml
      base = partial(_model.SwinTransformer, embed_dim=128, depths=[2,2,18,2], num_heads=[ 4, 8, 16, 32 ], img_size=size, window_size=(size//32), fused_window_process=True)
      model = dict(tiny=tiny, small=small, base=base)[scale]()

      if with_ckpt:
          assert size == 224, "only support size 224"
          assert scale == "tiny", "support tiny with ckpt only"
          ckpt = f"{HOME}/packs/ckpts/swin_tiny_patch4_window7_224.pth"
          with open(ckpt, "rb") as fin:
              state = torch.load(fin, map_location=torch.device("cpu"))
          model.load_state_dict(state["model"])

      if remove_head:
          print(model.head, flush=True)
          model.head = nn.Identity()
      elif only_backbone:
          def _forward(self, x):
              x = self.patch_embed(x)
              if self.ape:
                  x = x + self.absolute_pos_embed
              x = self.pos_drop(x)
              for layer in self.layers:
                  x = layer(x)
              # (B, L, C) -> (B, C, L) -> (B, C, sqrt(L), sqrt(L))
              x = x.permute(0, 2, 1)
              x = x.view(*x.shape[0:2], int(math.sqrt(x.shape[-1])), int(math.sqrt(x.shape[-1])))
              return x
          model.forward = partial(_forward, model)

      return model
  @staticmethod
  def build_convnext(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build a ConvNeXt model from the official repository checkout.

      Args:
          with_ckpt: load the bundled tiny checkpoint (``scale == "tiny"`` only).
          remove_head: replace ``model.head`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) replace
              ``forward`` so it returns the output of the fourth stage.
          scale: ``"tiny"``, ``"small"`` or ``"base"``.
          size: not used.
      """
      print("convnext ================================", flush=True)
      _model = import_abspy("convnext", f"{HOME}/packs/ConvNeXt/models")
      # Keep the constructors in the table and call only the selected one, so
      # the two unused model sizes are never instantiated.
      builders = dict(
          tiny=_model.convnext_tiny,
          small=_model.convnext_small,
          base=_model.convnext_base,
      )
      model = builders[scale]()

      if with_ckpt:
          assert scale == "tiny", "support tiny with ckpt only"
          ckpt =f"{HOME}/packs/ckpts/convnext_tiny_1k_224_ema.pth"
          with open(ckpt, "rb") as fin:
              state = torch.load(fin, map_location=torch.device("cpu"))
          model.load_state_dict(state["model"])

      if remove_head:
          print(model.head, flush=True)
          model.head = nn.Identity() # 768
      elif only_backbone:
          def _forward(self, x):
              for i in range(4):
                  x = self.downsample_layers[i](x)
                  x = self.stages[i](x)
              return x
          model.forward = partial(_forward, model)

      return model
  @staticmethod
  def build_hivit(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build a randomly initialised HiViT model (throughput testing only).

      Args:
          with_ckpt, remove_head, only_backbone: not supported by this builder.
          scale: ``"tiny"``, ``"small"`` or ``"base"``.
          size: input image size, passed on as ``img_size``.

      Raises:
          NotImplementedError: if any of the unsupported options is requested.
      """
      print("hivit [for testing throughput only] ================================", flush=True)
      # ``assert NotImplementedError`` always passes (the class is truthy), so
      # unsupported options were silently ignored; raise explicitly instead.
      if with_ckpt:
          raise NotImplementedError("build_hivit does not support with_ckpt")
      if remove_head:
          raise NotImplementedError("build_hivit does not support remove_head")
      if only_backbone:
          raise NotImplementedError("build_hivit does not support only_backbone")

      sys.path.insert(0, "")
      try:
          _model = import_abspy("hivit", f"{HOME}/packs/hivit/supervised/models/")
      finally:
          # Restore sys.path even when the import fails.
          if "" in sys.path:
              sys.path.remove("")

      tiny = partial(_model.HiViT, img_size=size, patch_size=16, inner_patches=4, embed_dim=384, depths=[1, 1, 10], num_heads=6, stem_mlp_ratio=3., mlp_ratio=4., ape=True, rpe=True,)
      small = partial(_model.HiViT, img_size=size, patch_size=16, inner_patches=4, embed_dim=384, depths=[2, 2, 20], num_heads=6, stem_mlp_ratio=3., mlp_ratio=4., ape=True, rpe=True,)
      base = partial(_model.HiViT, img_size=size, patch_size=16, inner_patches=4, embed_dim=512, depths=[2, 2, 20], num_heads=8, stem_mlp_ratio=3., mlp_ratio=4., ape=True, rpe=True,)
      model = dict(tiny=tiny, small=small, base=base)[scale]()
      return model
  @staticmethod
  def build_intern(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build an InternImage model (needs the compiled ``DCNv3`` op).

      Args:
          with_ckpt: load the bundled tiny checkpoint (``scale == "tiny"`` only).
          remove_head: replace ``model.head`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) replace
              ``forward`` so it returns the last level output permuted with
              ``(0, 3, 1, 2)``.
          scale: ``"tiny"``, ``"small"`` or ``"base"``.
          size: not used.
      """
      print("intern ================================", flush=True)
      specpath = f"{HOME}/packs/InternImage/classification"
      sys.path.insert(0, specpath)
      try:
          import DCNv3
          _model = import_abspy("intern_image", f"{HOME}/packs/InternImage/classification/models/")
      finally:
          # Undo the path change even when the import fails.
          if specpath in sys.path:
              sys.path.remove(specpath)

      tiny = partial(_model.InternImage, core_op='DCNv3', channels=64, depths=[4, 4, 18, 4], groups=[4, 8, 16, 32], offset_scale=1.0, mlp_ratio=4.,)
      small = partial(_model.InternImage, core_op='DCNv3', channels=80, depths=[4, 4, 21, 4], groups=[5, 10, 20, 40], layer_scale=1e-5, offset_scale=1.0, mlp_ratio=4., post_norm=True)
      base = partial(_model.InternImage, core_op='DCNv3', channels=112, depths=[4, 4, 21, 4], groups=[7, 14, 28, 56], layer_scale=1e-5, offset_scale=1.0, mlp_ratio=4., post_norm=True)
      model = dict(tiny=tiny, small=small, base=base)[scale]()

      if with_ckpt:
          assert scale == "tiny", "only support tiny model"
          ckpt = f"{HOME}/packs/ckpts/internimage_t_1k_224.pth"
          with open(ckpt, "rb") as fin:
              state = torch.load(fin, map_location=torch.device("cpu"))
          model.load_state_dict(state["model"])

      if remove_head:
          print(model.head, flush=True) # 768
          model.head = nn.Identity()
      elif only_backbone:
          def _forward(self, x):
              x = self.patch_embed(x)
              x = self.pos_drop(x)
              for level in self.levels:
                  x = level(x)
              return x.permute(0, 3, 1, 2)
          model.forward = partial(_forward, model)

      return model
  @staticmethod
  def build_xcit(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build an XCiT model.

      ``scale`` maps ``tiny -> xcit_small_12_p16``, ``small ->
      xcit_small_24_p16`` and ``base -> xcit_medium_24_p16``.

      Args:
          with_ckpt: load the bundled checkpoint (``scale == "tiny"`` only).
          remove_head: print the head and make ``forward`` return
              ``forward_features``.
          only_backbone: (ignored when ``remove_head`` is set) make
              ``forward`` return the patch tokens, class token dropped,
              reshaped to a square ``(B, C, H, W)`` map.
          size: not used.

      Without either flag ``forward`` also returns ``forward_features``.
      """
      print("xcit =================", flush=True)
      xcit = import_abspy("xcit", f"{HOME}/packs/xcit/")
      model = dict(tiny=xcit.xcit_small_12_p16, small=xcit.xcit_small_24_p16, base=xcit.xcit_medium_24_p16)[scale]()

      if with_ckpt:
          assert scale == "tiny", "only support tiny for ckpt"
          ckpt = f"{HOME}/packs/ckpts/xcit_small_12_p16_224.pth"
          with open(ckpt, "rb") as fin:
              state = torch.load(fin, map_location=torch.device("cpu"))
          model.load_state_dict(state["model"])

      def _forward_features_only(self, x):
          x = self.forward_features(x)
          return x

      def _forward_backbone(self, x):
          B, C, H, W = x.shape
          x, (Hp, Wp) = self.patch_embed(x)
          if self.use_pos:
              pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
              x = x + pos_encoding
          x = self.pos_drop(x)
          for blk in self.blocks:
              x = blk(x, Hp, Wp)
          cls_tokens = self.cls_token.expand(B, -1, -1)
          x = torch.cat((cls_tokens, x), dim=1)
          for blk in self.cls_attn_blocks:
              x = blk(x, Hp, Wp)
          # Drop the class token, then (B, L, C) -> (B, C, sqrt(L), sqrt(L)).
          x = x[:, 1:, :].permute(0, 2, 1)
          x = x.view(*x.shape[0:2], int(math.sqrt(x.shape[-1])), int(math.sqrt(x.shape[-1])),)
          return x

      if remove_head:
          print(model.head, flush=True)
          model.forward = partial(_forward_features_only, model)
      elif only_backbone:
          model.forward = partial(_forward_backbone, model)
      else:
          model.forward = partial(_forward_features_only, model)

      return model
  @staticmethod
  def build_swin_mmpretrain(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build a Swin Transformer classifier through mmpretrain.

      Args:
          with_ckpt: load the published tiny checkpoint, non-strictly
              (``scale == "tiny"`` only).
          remove_head: replace ``model.head.fc`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) make
              ``forward`` return the last backbone output.
          scale: passed on as the backbone ``arch``.
          size: only used to derive the window size (``size // 32``); the
              backbone ``img_size`` stays at 224.
      """
      print("swin scale [do not test throughput with this]================================", flush=True)
      from mmengine.runner import CheckpointLoader
      from mmpretrain.models import build_classifier, ImageClassifier

      backbone_cfg = dict(
          type='SwinTransformer', arch=scale, img_size=224, drop_path_rate=0.2)
      backbone_cfg["window_size"] = int(size // 32)
      head_cfg = dict(
          type='LinearClsHead',
          num_classes=1000,
          in_channels=1024 if scale == "base" else 768,
          init_cfg=None, # suppress the default init_cfg of LinearClsHead.
          loss=dict(
              type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
          cal_acc=False)
      classifier_cfg = dict(
          type='ImageClassifier',
          backbone=backbone_cfg,
          neck=dict(type='GlobalAveragePooling'),
          head=head_cfg,
          init_cfg=[
              dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
              dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
          ],
          train_cfg=dict(augments=[
              dict(type='Mixup', alpha=0.8),
              dict(type='CutMix', alpha=1.0)
          ]),
      )
      ckpt_url = "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth"

      net: ImageClassifier = build_classifier(classifier_cfg)

      if with_ckpt:
          assert scale == "tiny", "support tiny with ckpt only"
          weights = CheckpointLoader.load_checkpoint(ckpt_url)['state_dict']
          net.load_state_dict(weights, strict=False)

      if remove_head:
          print(net.head.fc, flush=True) # 768
          net.head.fc = nn.Identity()
      elif only_backbone:
          def _last_stage(self: ImageClassifier, x):
              return self.backbone(x)[-1]
          net.forward = partial(_last_stage, net)

      return net
  @staticmethod
  def build_hivit_mmpretrain(with_ckpt=False, remove_head=False, only_backbone=False, scale="tiny", size=224):
      """Build a HiViT-tiny classifier through mmpretrain.

      A ``HiViT`` subclass is registered as ``HiViTx``. It installs two
      ``load_state_dict`` pre-hooks: one resizes ``pos_embed`` when the
      checkpoint shape differs, the other delegates to SwinTransformer's
      relative-position-bias-table preparation.

      Args:
          with_ckpt: load the local tiny checkpoint (non-strict).
          remove_head: replace ``model.head.fc`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) make
              ``forward`` return the last backbone output.
          scale: must be ``"tiny"``.
          size: backbone ``img_size``.
      """
      assert scale == "tiny", "support tiny only"
      print("hivit scale [do not test throughput with this]================================", flush=True)
      from mmpretrain.models.builder import MODELS
      from mmengine.runner import CheckpointLoader
      from mmpretrain.models import build_classifier, ImageClassifier, HiViT, SwinTransformer
      from mmpretrain.models.backbones.vision_transformer import resize_pos_embed, to_2tuple, np

      class _HiViTx(HiViT):
          def __init__(self, *args,**kwargs):
              super().__init__(*args,**kwargs)
              # Attributes read by the hooks below.
              self.num_extra_tokens = 0
              self.interpolate_mode = "bicubic"
              self.patch_embed.init_out_size = self.patch_embed.patches_resolution
              self._register_load_state_dict_pre_hook(self._prepare_abs_pos_embed)
              self._register_load_state_dict_pre_hook(
                  self._prepare_relative_position_bias_table)

          # copied from SwinTransformer, change absolute_pos_embed to pos_embed
          def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
              name = prefix + 'pos_embed'
              if name not in state_dict.keys():
                  return

              ckpt_pos_embed_shape = state_dict[name].shape
              if self.pos_embed.shape != ckpt_pos_embed_shape:
                  from mmengine.logging import MMLogger
                  logger = MMLogger.get_current_instance()
                  logger.info(
                      'Resize the pos_embed shape from '
                      f'{ckpt_pos_embed_shape} to {self.pos_embed.shape}.')

                  ckpt_pos_embed_shape = to_2tuple(
                      int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
                  pos_embed_shape = self.patch_embed.init_out_size

                  state_dict[name] = resize_pos_embed(state_dict[name],
                                                      ckpt_pos_embed_shape,
                                                      pos_embed_shape,
                                                      self.interpolate_mode,
                                                      self.num_extra_tokens)

          def _prepare_relative_position_bias_table(self, state_dict, *args, **kwargs):
              # pop() tolerates checkpoints that lack this entry (and repeated
              # loads of the same dict); ``del`` raised KeyError in that case.
              state_dict.pop('backbone.relative_position_index', None)
              return SwinTransformer._prepare_relative_position_bias_table(self, state_dict, *args, **kwargs)

      try:
          @MODELS.register_module()
          class HiViTx(_HiViTx):
              ...
      except Exception as e:
          # Best effort: registration errors are printed, not raised.
          # NOTE(review): presumably this covers repeated calls re-registering
          # the same name -- confirm.
          print(e)

      print("hivit ================================", flush=True)
      model = dict(
          backbone=dict(
              ape=True,
              arch='tiny',
              drop_path_rate=0.05,
              img_size=size,
              rpe=True,
              type='HiViTx'),
          head=dict(
              cal_acc=False,
              in_channels=384,
              init_cfg=None,
              loss=dict(
                  label_smooth_val=0.1, mode='original', type='LabelSmoothLoss'),
              num_classes=1000,
              type='LinearClsHead'),
          init_cfg=[
              dict(bias=0.0, layer='Linear', std=0.02, type='TruncNormal'),
              dict(bias=0.0, layer='LayerNorm', type='Constant', val=1.0),
          ],
          neck=dict(type='GlobalAveragePooling'),
          train_cfg=dict(augments=[
              dict(alpha=0.8, type='Mixup'),
              dict(alpha=1.0, type='CutMix'),
          ]),
          type='ImageClassifier')
      model = build_classifier(model)

      if with_ckpt:
          assert scale == "tiny", "support tiny with ckpt only"
          ckpt = f"{HOME}/packs/ckpts/hivit-tiny-p16_8xb128_in1k/epoch_295.pth"
          model.load_state_dict(CheckpointLoader.load_checkpoint(ckpt)['state_dict'], strict=False)

      if remove_head:
          print(model.head.fc, flush=True)  # head is configured with in_channels=384
          model.head.fc = nn.Identity()
      elif only_backbone:
          def forward_backbone(self: ImageClassifier, x):
              x = self.backbone(x)[-1]
              return x
          model.forward = partial(forward_backbone, model)

      return model
  @staticmethod
  def build_deit_mmpretrain(with_ckpt=False, remove_head=False, only_backbone=False, scale="small", size=224, test_flops=False):
      """Build a DeiT classifier (``"small"`` or ``"base"``) through mmpretrain.

      Args:
          with_ckpt: load the published checkpoint for ``scale`` (non-strict).
          remove_head: replace ``model.head.layers.head`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) switch the
              backbone to ``out_type='featmap'`` and make ``forward`` return
              the last backbone output.
          scale: ``"small"`` or ``"base"``.
          size: backbone ``img_size``.
          test_flops: swap every layer's attention for the pure-python
              ``scaled_dot_product_attention_pyimpl``.
      """
      print("deit ================================", flush=True)
      from mmengine.runner import CheckpointLoader
      from mmpretrain.models import build_classifier, ImageClassifier, HiViT, VisionTransformer, SwinTransformer
      from mmpretrain.models.backbones.vision_transformer import resize_pos_embed, to_2tuple, np

      def _classifier_cfg(arch, in_channels, **backbone_extra):
          # The two variants share everything except arch, head width and
          # optional extra backbone settings.
          return dict(
              type='ImageClassifier',
              backbone=dict(
                  type='VisionTransformer',
                  arch=arch,
                  img_size=size,
                  patch_size=16,
                  **backbone_extra),
              neck=None,
              head=dict(
                  type='VisionTransformerClsHead',
                  num_classes=1000,
                  in_channels=in_channels,
                  loss=dict(
                      type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
              ),
              init_cfg=[
                  dict(type='TruncNormal', layer='Linear', std=.02),
                  dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
              ],
              train_cfg=dict(augments=[
                  dict(type='Mixup', alpha=0.8),
                  dict(type='CutMix', alpha=1.0)
              ]),
          )

      configs = dict(
          small=_classifier_cfg('deit-small', 384),
          base=_classifier_cfg('deit-base', 768, drop_path_rate=0.1),
      )
      checkpoints = dict(
          small="https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth",
          base="https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth",
      )
      selected_cfg = configs[scale]
      ckpt_url = checkpoints[scale]

      net = build_classifier(selected_cfg)

      if with_ckpt:
          weights = CheckpointLoader.load_checkpoint(ckpt_url)['state_dict']
          net.load_state_dict(weights, strict=False)

      if remove_head:
          print(net.head.layers.head, flush=True)
          net.head.layers.head = nn.Identity() # 384->1000
      elif only_backbone:
          net.backbone.out_type = 'featmap'

          def _last_stage(self: ImageClassifier, x):
              return self.backbone(x)[-1]
          net.forward = partial(_last_stage, net)

      if not test_flops:
          print("WARNING: this mode will make flops lower, do not use this to test flops!", flush=True)
      else:
          print("WARNING: this mode may make throughput lower, used to test flops only!", flush=True)
          from mmpretrain.models.utils.attention import scaled_dot_product_attention_pyimpl
          for block in net.backbone.layers:
              block.attn.scaled_dot_product_attention = scaled_dot_product_attention_pyimpl

      return net
  @staticmethod
  def build_resnet_mmpretrain(with_ckpt=False, remove_head=False, only_backbone=False, scale="r50", size=224):
      """Build a ResNet-50 (``"r50"``) or ResNet-101 (``"r101"``) classifier through mmpretrain.

      Args:
          with_ckpt: load the published checkpoint for ``scale`` (strict).
          remove_head: replace ``model.head.fc`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) make
              ``forward`` return the last backbone output.
          size: not used.
      """
      print("resnet ================================", flush=True)
      from mmengine.runner import CheckpointLoader
      from mmpretrain.models import build_classifier, ImageClassifier

      def _classifier_cfg(depth):
          # Both variants differ only in the backbone depth.
          return dict(
              type='ImageClassifier',
              backbone=dict(
                  type='ResNet',
                  depth=depth,
                  num_stages=4,
                  out_indices=(3, ),
                  style='pytorch'),
              neck=dict(type='GlobalAveragePooling'),
              head=dict(
                  type='LinearClsHead',
                  num_classes=1000,
                  in_channels=2048,
                  loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
                  topk=(1, 5),
              ))

      configs = dict(r50=_classifier_cfg(50), r101=_classifier_cfg(101))
      checkpoints = dict(
          r50="https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
          r101="https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
      )
      selected_cfg = configs[scale]
      ckpt_url = checkpoints[scale]

      net = build_classifier(selected_cfg)

      if with_ckpt:
          weights = CheckpointLoader.load_checkpoint(ckpt_url)['state_dict']
          net.load_state_dict(weights)

      if remove_head:
          print(net.head.fc, flush=True)
          net.head.fc = nn.Identity() # 2048->1000
      elif only_backbone:
          def _last_stage(self: ImageClassifier, x):
              return self.backbone(x)[-1]
          net.forward = partial(_last_stage, net)

      return net
  @staticmethod
  def build_replknet31b_mmpretrain(with_ckpt=False, remove_head=False, only_backbone=False, scale="31b", size=224):
      """Build a RepLKNet-31B classifier through mmpretrain.

      Args:
          with_ckpt: load the published checkpoint (strict).
          remove_head: replace ``model.head.fc`` with ``nn.Identity``.
          only_backbone: (ignored when ``remove_head`` is set) make
              ``forward`` return the last backbone output.
          scale, size: not used.
      """
      print("replknet31b ================================", flush=True)
      from mmengine.runner import CheckpointLoader
      from mmpretrain.models import build_classifier, ImageClassifier

      backbone_cfg = dict(
          type='RepLKNet',
          arch='31B',
          out_indices=(3, ),
      )
      head_cfg = dict(
          type='LinearClsHead',
          num_classes=1000,
          in_channels=1024,
          loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
          topk=(1, 5),
      )
      classifier_cfg = dict(
          type='ImageClassifier',
          backbone=backbone_cfg,
          neck=dict(type='GlobalAveragePooling'),
          head=head_cfg)
      ckpt_url = "https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k_20221118-fd08e268.pth"

      net = build_classifier(classifier_cfg)

      if with_ckpt:
          weights = CheckpointLoader.load_checkpoint(ckpt_url)['state_dict']
          net.load_state_dict(weights)

      if remove_head:
          print(net.head.fc, flush=True)
          net.head.fc = nn.Identity()
      elif only_backbone:
          def _last_stage(self: ImageClassifier, x):
              return self.backbone(x)[-1]
          net.forward = partial(_last_stage, net)

      return net
  @staticmethod
  def build_mmpretrain_models(cfg="swin_tiny", ckpt=True, only_backbone=False, with_norm=True, **kwargs):
      """Build an mmpretrain classifier from one of the bundled config files.

      Args:
          cfg: name of the model, e.g. ``"swin_tiny"`` or ``"resnet50"``.
          ckpt: when truthy, load the published checkpoint for ``cfg``.
          only_backbone: make ``forward`` return the last backbone output
              (ConvNeXt: ``gap_before_final_norm`` is disabled;
              VisionTransformer: ``out_type`` is set to ``'featmap'``).
          with_norm: only read when ``only_backbone`` is set; see the note in
              the body.
          **kwargs: accepted and ignored.

      Returns:
          The built model, or ``None`` when ``cfg`` is not a known name.
      """
      import os
      from functools import partial
      from mmengine.runner import CheckpointLoader
      from mmpretrain.models import build_classifier, ImageClassifier, ConvNeXt, VisionTransformer, SwinTransformer
      from mmengine.config import Config

      config_root = os.path.join(os.path.dirname(__file__), "../../analyze/mmpretrain_configs/configs/")

      # name -> (config file relative to config_root, checkpoint URL).
      # Only the selected config file is parsed, so a problem in an unrelated
      # config no longer breaks every call.
      CFGS = dict(
          swin_tiny=(
              "./swin_transformer/swin-tiny_16xb64_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
          ),
          convnext_tiny=(
              "./convnext/convnext-tiny_32xb128_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/convnext/convnext-tiny_32xb128_in1k_20221207-998cf3e9.pth",
          ),
          deit_small=(
              "./deit/deit-small_4xb256_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth",
          ),
          resnet50=(
              "./resnet/resnet50_8xb32_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
          ),
          # ================================
          swin_small=(
              "./swin_transformer/swin-small_16xb64_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth",
          ),
          convnext_small=(
              "./convnext/convnext-small_32xb128_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/convnext/convnext-small_32xb128_in1k_20221207-4ab7052c.pth",
          ),
          deit_base=(
              "./deit/deit-base_16xb64_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth",
          ),
          resnet101=(
              "./resnet/resnet101_8xb32_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_8xb32_in1k_20210831-539c63f8.pth",
          ),
          # ================================
          swin_base=(
              "./swin_transformer/swin-base_16xb64_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth",
          ),
          convnext_base=(
              "./convnext/convnext-base_32xb128_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_32xb128_in1k_20221207-fbdb5eb9.pth",
          ),
          replknet_base=(
              # comment this "from mmpretrain.models import build_classifier" in __base__/models/replknet...
              "./replknet/replknet-31B_32xb64_in1k.py",
              "https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_3rdparty_in1k_20221118-fd08e268.pth",
          ),
      )

      if cfg not in CFGS:
          return None

      config_file, ckpt_url = CFGS[cfg]
      model_cfg = Config.fromfile(os.path.join(config_root, config_file)).to_dict()['model']
      model: ImageClassifier = build_classifier(model_cfg)
      if ckpt:
          model.load_state_dict(CheckpointLoader.load_checkpoint(ckpt_url)['state_dict'])

      if only_backbone:
          if isinstance(model.backbone, ConvNeXt):
              model.backbone.gap_before_final_norm = False
          if isinstance(model.backbone, VisionTransformer):
              model.backbone.out_type = 'featmap'

          def forward_backbone(self: ImageClassifier, x):
              x = self.backbone(x)[-1]
              return x

          if not with_norm:
              # NOTE(review): this sets ``norm<i>`` on the classifier, not on
              # ``model.backbone`` -- confirm it really disables the norm.
              setattr(model, f"norm{model.backbone.out_indices[-1]}", lambda x: x)
          model.forward = partial(forward_backbone, model)

      return model
  1277. @classmethod
  1278. def check(cls):
  1279. for mbuild in [
  1280. # partial(cls.build_vmamba, scale="tv0"),
  1281. # partial(cls.build_vmamba, scale="tv1"),
  1282. # partial(cls.build_vmamba, scale="tv2"),
  1283. # partial(cls.build_vmamba, scale="sv0"),
  1284. # partial(cls.build_vmamba, scale="sv2"),
  1285. # partial(cls.build_vmamba, scale="bv0"),
  1286. # partial(cls.build_vmamba, scale="bv2"),
  1287. # partial(cls.build_swin, scale="tiny"),
  1288. # partial(cls.build_swin, scale="small"),
  1289. # partial(cls.build_swin, scale="base"),
  1290. # partial(cls.build_convnext, scale="tiny"),
  1291. # partial(cls.build_convnext, scale="small"),
  1292. # partial(cls.build_convnext, scale="base"),
  1293. # partial(cls.build_hivit, scale="tiny"),
  1294. # partial(cls.build_hivit, scale="small"),
  1295. # partial(cls.build_hivit, scale="base"),
  1296. # partial(cls.build_intern, scale="tiny"),
  1297. # partial(cls.build_intern, scale="small"),
  1298. # partial(cls.build_intern, scale="base"),
  1299. # partial(cls.build_xcit, scale="tiny"),
  1300. # partial(cls.build_xcit, scale="small"),
  1301. # partial(cls.build_xcit, scale="base"),
  1302. # partial(cls.build_swin_mmpretrain, scale="tiny"),
  1303. # partial(cls.build_swin_mmpretrain, scale="small"),
  1304. # partial(cls.build_swin_mmpretrain, scale="base"),
  1305. # partial(cls.build_hivit_mmpretrain, scale="tiny"),
  1306. # partial(cls.build_hivit_mmpretrain, scale="small"),
  1307. # partial(cls.build_hivit_mmpretrain, scale="base"),
  1308. # partial(cls.build_deit_mmpretrain, scale="small"),
  1309. # partial(cls.build_deit_mmpretrain, scale="base"),
  1310. # partial(cls.build_resnet_mmpretrain, scale="r50"),
  1311. # partial(cls.build_resnet_mmpretrain, scale="r101"),
  1312. # partial(cls.build_replknet31b_mmpretrain, scale="31b"),
  1313. ]:
  1314. for size in [224, 768]:
  1315. inp = torch.randn((2, 3, size, size)).cuda()
  1316. for with_ckpt in [False, True]:
  1317. for remove_head in [False, True]:
  1318. for only_backbone in [False, True]:
  1319. if False:
  1320. model = mbuild(with_ckpt=with_ckpt, remove_head=remove_head, only_backbone=only_backbone, size=size).cuda()
  1321. print(size, with_ckpt, remove_head, only_backbone, model(inp).shape, flush=True)
  1322. try:
  1323. model = mbuild(with_ckpt=with_ckpt, remove_head=remove_head, only_backbone=only_backbone).cuda()
  1324. print(size, with_ckpt, remove_head, only_backbone, model(inp).shape, flush=True)
  1325. except Exception as e:
  1326. print(size, with_ckpt, remove_head, only_backbone, flush=True)
  1327. print("ERROR:", e, flush=True)
  1328. breakpoint()
  1329. # used for print flops
  1330. class FLOPs:
  1331. @staticmethod
  1332. def register_supported_ops():
  1333. build = import_abspy("models", os.path.join(os.path.dirname(os.path.abspath(__file__)), "../classification/"))
  1334. selective_scan_flop_jit: Callable = build.vmamba.selective_scan_flop_jit
  1335. # flops_selective_scan_fn: Callable = build.vmamba.flops_selective_scan_fn
  1336. # flops_selective_scan_ref: Callable = build.vmamba.flops_selective_scan_ref
  1337. def causal_conv_1d_jit(inputs, outputs):
  1338. """
  1339. https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
  1340. x: (batch, dim, seqlen) weight: (dim, width) bias: (dim,) out: (batch, dim, seqlen)
  1341. out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
  1342. """
  1343. from fvcore.nn.jit_handles import conv_flop_jit
  1344. return conv_flop_jit(inputs, outputs)
  1345. supported_ops={
  1346. "aten::gelu": None, # as relu is in _IGNORED_OPS
  1347. "aten::silu": None, # as relu is in _IGNORED_OPS
  1348. "aten::neg": None, # as relu is in _IGNORED_OPS
  1349. "aten::exp": None, # as relu is in _IGNORED_OPS
  1350. "aten::flip": None, # as permute is in _IGNORED_OPS
  1351. # =====================================================
  1352. # for mamba-ssm
  1353. "prim::PythonOp.CausalConv1dFn": causal_conv_1d_jit,
  1354. "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit,
  1355. # =====================================================
  1356. # for VMamba
  1357. "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
  1358. "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
  1359. # "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
  1360. "prim::PythonOp.SelectiveScan": selective_scan_flop_jit,
  1361. "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
  1362. # =====================================================
  1363. # "aten::scaled_dot_product_attention": ...
  1364. }
  1365. return supported_ops
  1366. @staticmethod
  1367. def check_operations(model: nn.Module, inputs=None, input_shape=(3, 224, 224)):
  1368. from fvcore.nn.jit_analysis import _get_scoped_trace_graph, _named_modules_with_dup, Counter, JitModelAnalysis
  1369. if inputs is None:
  1370. assert input_shape is not None
  1371. if len(input_shape) == 1:
  1372. input_shape = (1, 3, input_shape[0], input_shape[0])
  1373. elif len(input_shape) == 2:
  1374. input_shape = (1, 3, *input_shape)
  1375. elif len(input_shape) == 3:
  1376. input_shape = (1, *input_shape)
  1377. else:
  1378. assert len(input_shape) == 4
  1379. inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
  1380. model.eval()
  1381. flop_counter = JitModelAnalysis(model, inputs)
  1382. flop_counter._ignored_ops = set()
  1383. flop_counter._op_handles = dict()
  1384. assert flop_counter.total() == 0 # make sure no operations supported
  1385. print(flop_counter.unsupported_ops(), flush=True)
  1386. print(f"supported ops {flop_counter._op_handles}; ignore ops {flop_counter._ignored_ops};", flush=True)
  1387. @classmethod
  1388. def fvcore_flop_count(cls, model: nn.Module, inputs=None, input_shape=(3, 224, 224), show_table=False, show_arch=False, verbose=True):
  1389. supported_ops = cls.register_supported_ops()
  1390. from fvcore.nn.parameter_count import parameter_count as fvcore_parameter_count
  1391. from fvcore.nn.flop_count import flop_count, FlopCountAnalysis, _DEFAULT_SUPPORTED_OPS
  1392. from fvcore.nn.print_model_statistics import flop_count_str, flop_count_table
  1393. from fvcore.nn.jit_analysis import _IGNORED_OPS
  1394. from fvcore.nn.jit_handles import get_shape, addmm_flop_jit
  1395. if inputs is None:
  1396. assert input_shape is not None
  1397. if len(input_shape) == 1:
  1398. input_shape = (1, 3, input_shape[0], input_shape[0])
  1399. elif len(input_shape) == 2:
  1400. input_shape = (1, 3, *input_shape)
  1401. elif len(input_shape) == 3:
  1402. input_shape = (1, *input_shape)
  1403. else:
  1404. assert len(input_shape) == 4
  1405. inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
  1406. model.eval()
  1407. Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=supported_ops)
  1408. flops_table = flop_count_table(
  1409. flops = FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
  1410. max_depth=100,
  1411. activations=None,
  1412. show_param_shapes=True,
  1413. )
  1414. flops_str = flop_count_str(
  1415. flops = FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
  1416. activations=None,
  1417. )
  1418. if show_arch:
  1419. print(flops_str)
  1420. if show_table:
  1421. print(flops_table)
  1422. params = fvcore_parameter_count(model)[""]
  1423. flops = sum(Gflops.values())
  1424. if verbose:
  1425. print(Gflops.items())
  1426. print("GFlops: ", flops, "Params: ", params, flush=True)
  1427. return params, flops
  1428. # equals with fvcore_flop_count
  1429. @classmethod
  1430. def mmengine_flop_count(cls, model: nn.Module = None, input_shape = (3, 224, 224), show_table=False, show_arch=False, _get_model_complexity_info=False):
  1431. supported_ops = cls.register_supported_ops()
  1432. from mmengine.analysis.print_helper import is_tuple_of, FlopAnalyzer, ActivationAnalyzer, parameter_count, _format_size, complexity_stats_table, complexity_stats_str
  1433. from mmengine.analysis.jit_analysis import _IGNORED_OPS
  1434. from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS, _DEFAULT_SUPPORTED_ACT_OPS
  1435. from mmengine.analysis import get_model_complexity_info as mm_get_model_complexity_info
  1436. # modified from mmengine.analysis
  1437. def get_model_complexity_info(
  1438. model: nn.Module,
  1439. input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
  1440. None] = None,
  1441. inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
  1442. None] = None,
  1443. show_table: bool = True,
  1444. show_arch: bool = True,
  1445. ):
  1446. if input_shape is None and inputs is None:
  1447. raise ValueError('One of "input_shape" and "inputs" should be set.')
  1448. elif input_shape is not None and inputs is not None:
  1449. raise ValueError('"input_shape" and "inputs" cannot be both set.')
  1450. if inputs is None:
  1451. device = next(model.parameters()).device
  1452. if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
  1453. inputs = (torch.randn(1, *input_shape).to(device), )
  1454. elif is_tuple_of(input_shape, tuple) and all([
  1455. is_tuple_of(one_input_shape, int)
  1456. for one_input_shape in input_shape # type: ignore
  1457. ]): # tuple of tuple of int, construct multiple tensors
  1458. inputs = tuple([
  1459. torch.randn(1, *one_input_shape).to(device)
  1460. for one_input_shape in input_shape # type: ignore
  1461. ])
  1462. else:
  1463. raise ValueError(
  1464. '"input_shape" should be either a `tuple of int` (to construct'
  1465. 'one input tensor) or a `tuple of tuple of int` (to construct'
  1466. 'multiple input tensors).')
  1467. flop_handler = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops)
  1468. # activation_handler = ActivationAnalyzer(model, inputs)
  1469. flops = flop_handler.total()
  1470. # activations = activation_handler.total()
  1471. params = parameter_count(model)['']
  1472. flops_str = _format_size(flops)
  1473. # activations_str = _format_size(activations)
  1474. params_str = _format_size(params)
  1475. if show_table:
  1476. complexity_table = complexity_stats_table(
  1477. flops=flop_handler,
  1478. # activations=activation_handler,
  1479. show_param_shapes=True,
  1480. )
  1481. complexity_table = '\n' + complexity_table
  1482. else:
  1483. complexity_table = ''
  1484. if show_arch:
  1485. complexity_arch = complexity_stats_str(
  1486. flops=flop_handler,
  1487. # activations=activation_handler,
  1488. )
  1489. complexity_arch = '\n' + complexity_arch
  1490. else:
  1491. complexity_arch = ''
  1492. return {
  1493. 'flops': flops,
  1494. 'flops_str': flops_str,
  1495. # 'activations': activations,
  1496. # 'activations_str': activations_str,
  1497. 'params': params,
  1498. 'params_str': params_str,
  1499. 'out_table': complexity_table,
  1500. 'out_arch': complexity_arch
  1501. }
  1502. if _get_model_complexity_info:
  1503. return get_model_complexity_info
  1504. model.eval()
  1505. analysis_results = get_model_complexity_info(
  1506. model,
  1507. input_shape,
  1508. show_table=show_table,
  1509. show_arch=show_arch,
  1510. )
  1511. flops = analysis_results['flops_str']
  1512. params = analysis_results['params_str']
  1513. # activations = analysis_results['activations_str']
  1514. out_table = analysis_results['out_table']
  1515. out_arch = analysis_results['out_arch']
  1516. if show_arch:
  1517. print(out_arch)
  1518. if show_table:
  1519. print(out_table)
  1520. split_line = '=' * 30
  1521. print(f'{split_line}\nInput shape: {input_shape}\t'
  1522. f'Flops: {flops}\tParams: {params}\t'
  1523. # f'Activation: {activations}\n{split_line}'
  1524. , flush=True)
  1525. # print('!!!Only the backbone network is counted in FLOPs analysis.')
  1526. # print('!!!Please be cautious if you use the results in papers. '
  1527. # 'You may need to check if all ops are supported and verify that the '
  1528. # 'flops computation is correct.')
  1529. @classmethod
  1530. def mmdet_flops(cls, config=None, extra_config=None):
  1531. from mmengine.config import Config
  1532. from mmengine.runner import Runner
  1533. import numpy as np
  1534. import os
  1535. cfg = Config.fromfile(config)
  1536. if "model" in cfg:
  1537. if "pretrained" in cfg["model"]:
  1538. cfg["model"].pop("pretrained")
  1539. if extra_config is not None:
  1540. new_cfg = Config.fromfile(extra_config)
  1541. new_cfg["model"] = cfg["model"]
  1542. cfg = new_cfg
  1543. cfg["work_dir"] = "/tmp"
  1544. cfg["default_scope"] = "mmdet"
  1545. runner = Runner.from_cfg(cfg)
  1546. model = runner.model.cuda()
  1547. get_model_complexity_info = cls.mmengine_flop_count(_get_model_complexity_info=True)
  1548. if True:
  1549. oridir = os.getcwd()
  1550. os.chdir(os.path.join(os.path.dirname(__file__), "../detection"))
  1551. data_loader = runner.val_dataloader
  1552. num_images = 100
  1553. mean_flops = []
  1554. for idx, data_batch in enumerate(data_loader):
  1555. if idx == num_images:
  1556. break
  1557. data = model.data_preprocessor(data_batch)
  1558. model.forward = partial(model.forward, data_samples=data['data_samples'])
  1559. # out = get_model_complexity_info(model, inputs=data['inputs'])
  1560. out = get_model_complexity_info(model, input_shape=(3, 1280, 800))
  1561. params = out['params_str']
  1562. mean_flops.append(out['flops'])
  1563. mean_flops = np.average(np.array(mean_flops))
  1564. print(params, mean_flops)
  1565. os.chdir(oridir)
  1566. @classmethod
  1567. def mmseg_flops(cls, config=None, input_shape=(3, 512, 2048)):
  1568. from mmengine.config import Config
  1569. from mmengine.runner import Runner
  1570. cfg = Config.fromfile(config)
  1571. cfg["work_dir"] = "/tmp"
  1572. cfg["default_scope"] = "mmseg"
  1573. runner = Runner.from_cfg(cfg)
  1574. model = runner.model.cuda()
  1575. cls.fvcore_flop_count(model, input_shape=input_shape)
  1576. if __name__ == "__main__":
  1577. BuildModels.check()