util.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713
  1. import copy
  2. import os
  3. import sys
  4. import time
  5. import logging
  6. import shutil
  7. import argparse
  8. import torch
  9. from tensorboardX import SummaryWriter
  10. from typing import Callable
  11. from functools import partial
  12. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from torchvision import datasets, transforms
  14. # from mmcv import Config
  15. # from mmcv.cnn.utils import get_model_complexity_info
  16. # from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
  17. def str2bool(v):
  18. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  19. return True
  20. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  21. return False
  22. else:
  23. raise argparse.ArgumentTypeError('Unsupported value encountered.')
  24. def run_pre(cfg):
  25. # from time
  26. if cfg.sleep > -1:
  27. for i in range(cfg.sleep):
  28. time.sleep(1)
  29. print('\rCount down : {} s'.format(cfg.sleep - 1 - i), end='')
  30. # from memory
  31. elif cfg.memory > -1:
  32. s_times = 0
  33. while True:
  34. os.system('nvidia-smi -q -d Memory | grep -A4 GPU | grep Used > tmp')
  35. memory_used = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
  36. if memory_used[0] < 3000:
  37. os.system('rm tmp')
  38. break
  39. else:
  40. s_times += 1
  41. time.sleep(1)
  42. print('\rWaiting for {} s'.format(s_times), end='')
  43. def makedirs(dirs, exist_ok=False):
  44. if not isinstance(dirs, list):
  45. dirs = [dirs]
  46. for dir in dirs:
  47. os.makedirs(dir, exist_ok=exist_ok)
  48. def init_checkpoint(cfg):
  49. def rm_zero_size_file(path):
  50. files = os.listdir(path)
  51. for file in files:
  52. path = '{}/{}'.format(cfg.logdir, file)
  53. size = os.path.getsize(path) # unit:B
  54. if os.path.isfile(path) and size < 8:
  55. os.remove(path)
  56. os.makedirs(cfg.trainer.checkpoint, exist_ok=True)
  57. if cfg.trainer.resume_dir:
  58. cfg.logdir = '{}/{}'.format(cfg.trainer.checkpoint, cfg.trainer.resume_dir)
  59. checkpoint_path = cfg.model.model_kwargs['checkpoint_path']
  60. if checkpoint_path == '':
  61. cfg.model.model_kwargs['checkpoint_path'] = '{}/latest_ckpt.pth'.format(cfg.logdir)
  62. else:
  63. cfg.model.model_kwargs['checkpoint_path'] = '{}/{}'.format(cfg.logdir, checkpoint_path.split('/')[-1])
  64. state_dict = torch.load(cfg.model.model_kwargs['checkpoint_path'], map_location='cpu')
  65. cfg.trainer.iter, cfg.trainer.epoch = state_dict['iter'], state_dict['epoch']
  66. cfg.trainer.topk_recorder = state_dict['topk_recorder']
  67. else:
  68. if cfg.master:
  69. logdir = '{}_{}_{}_{}'.format(cfg.trainer.name, cfg.model.name, cfg.data.type, time.strftime("%Y%m%d-%H%M%S"))
  70. cfg.logdir = '{}/{}'.format(cfg.trainer.checkpoint, logdir)
  71. os.makedirs(cfg.logdir, exist_ok=True)
  72. shutil.copy('{}.py'.format('/'.join(cfg.cfg_path.split('.'))), '{}/{}.py'.format(cfg.logdir, cfg.cfg_path.split('.')[-1]))
  73. else:
  74. cfg.logdir = None
  75. cfg.trainer.iter, cfg.trainer.epoch = 0, 0
  76. cfg.trainer.topk_recorder = dict()
  77. cfg.trainer.topk_recorder = dict(net_top1=[], net_top5=[], net_E_top1=[], net_E_top5=[])
  78. cfg.logger = get_logger(cfg) if cfg.master else None
  79. cfg.writer = SummaryWriter(log_dir=cfg.logdir, comment='') if cfg.master else None
  80. log_msg(cfg.logger, f'==> Logging on master GPU: {cfg.logger_rank}')
  81. # rm_zero_size_file(cfg.logdir) if cfg.master else None
  82. def log_cfg(cfg):
  83. def _parse_Namespace(cfg, base_str=''):
  84. ret = {}
  85. if hasattr(cfg, '__dict__'):
  86. for key, val in cfg.__dict__.items():
  87. if not key.startswith('_'):
  88. ret.update(_parse_Namespace(val, '{}.{}'.format(base_str, key).lstrip('.')))
  89. else:
  90. ret.update({base_str:cfg})
  91. return ret
  92. cfg_dict = _parse_Namespace(cfg)
  93. key_max_length = max(list(map(len, cfg_dict.keys())))
  94. excludes = ['writer.', 'logger.handlers']
  95. exclude_keys = []
  96. for k, v in cfg_dict.items():
  97. for exclude in excludes:
  98. if k.find(exclude) != -1:
  99. exclude_keys.append(k) if k not in exclude_keys else None
  100. # cfg_str = '\n'.join(
  101. # [(('{' + ':<{}'.format(key_max_length) + '} : {' + ':<{}'.format(key_max_length)) + '}').format(k, str(v)) for
  102. # k, v in cfg_dict.items()])
  103. cfg_str = ''
  104. for k, v in cfg_dict.items():
  105. if k in exclude_keys:
  106. continue
  107. cfg_str += ('{' + ':<{}'.format(key_max_length) + '} : {' + ':<{}'.format(key_max_length) + '}').format(k, str(v))
  108. cfg_str += '\n'
  109. cfg_str = cfg_str.strip()
  110. cfg.cfg_dict, cfg.cfg_str = cfg_dict, cfg_str
  111. log_msg(cfg.logger, f'==> ********** cfg ********** \n{cfg.cfg_str}')
  112. def get_logger(cfg, mode='a+'):
  113. log_format = '%(asctime)s - %(message)s'
  114. logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
  115. fh = logging.FileHandler('{}/log_{}.txt'.format(cfg.logdir, cfg.mode), mode=mode)
  116. fh.setFormatter(logging.Formatter(log_format))
  117. logger = logging.getLogger()
  118. logger.addHandler(fh)
  119. cfg.logger = logger
  120. return logger
  121. def able(ret, mark=False, default=None):
  122. return ret if mark else default
  123. def log_msg(logger, msg, level='info'):
  124. if logger is not None:
  125. if msg is not None and level == 'info':
  126. logger.info(msg)
  127. class AvgMeter(object):
  128. def __init__(self, name, fmt=':f', show_name='val', add_name=''):
  129. self.name = name
  130. self.fmt = fmt
  131. self.show_name = show_name
  132. self.add_name = add_name
  133. self.reset()
  134. def reset(self):
  135. self.val = 0
  136. self.avg = 0
  137. self.sum = 0
  138. self.count = 0
  139. def update(self, val, n=1):
  140. self.val = val
  141. self.sum += val * n
  142. self.count += n
  143. self.avg = self.sum / self.count
  144. def __str__(self):
  145. fmtstr = '[{name} {' + self.show_name + self.fmt + '}'
  146. fmtstr += (' ({' + self.add_name + self.fmt + '})]' if self.add_name else ']')
  147. return fmtstr.format(**self.__dict__)
  148. class ProgressMeter(object):
  149. def __init__(self, meters, default_prefix=""):
  150. self.iter_fmtstr_iter = '{}: {:>3.2f}% [{}/{}]'
  151. self.iter_fmtstr_batch = ' [{:<.1f}/{:<3.1f}]'
  152. self.meters = meters
  153. self.default_prefix = default_prefix
  154. def get_msg(self, iter, iter_full, epoch=None, epoch_full=None, prefix=None):
  155. entries = [self.iter_fmtstr_iter.format(prefix if prefix else self.default_prefix, iter / iter_full * 100, iter, iter_full, epoch, epoch_full)]
  156. if epoch:
  157. entries += [self.iter_fmtstr_batch.format(epoch, epoch_full)]
  158. for meter in self.meters.values():
  159. entries.append(str(meter)) if meter.count > 0 else None
  160. return ' '.join(entries)
  161. def get_log_terms(log_terms, default_prefix=''):
  162. terms = {}
  163. for t in log_terms:
  164. t = {k: v for k, v in t.items()}
  165. t_name = t['name']
  166. terms[t_name] = AvgMeter(**t)
  167. progress = ProgressMeter(terms, default_prefix=default_prefix)
  168. return terms, progress
  169. def update_log_term(term, val, n, master):
  170. term.update(val, n) if term and master else None
  171. def accuracy(output, target, topk=(1,)):
  172. maxk = max(topk)
  173. batch_size = target.size(0)
  174. _, pred = output.topk(maxk, 1, True, True)
  175. pred = pred.t()
  176. correct = pred.eq(target.reshape(1, -1).expand_as(pred))
  177. return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk], [correct[:k].reshape(-1).float().sum(0) for k in topk] + [batch_size]
  178. def get_timepc():
  179. if torch.cuda.is_available():
  180. torch.cuda.synchronize()
  181. return time.perf_counter()
  182. def get_net_params(net):
  183. num_params = 0
  184. for param in net.parameters():
  185. if param.requires_grad:
  186. num_params += param.numel()
  187. return num_params / 1e6
  188. def import_abspy(name="models", path="classification/"):
  189. import sys
  190. import importlib
  191. path = os.path.abspath(path)
  192. assert os.path.isdir(path)
  193. sys.path.insert(0, path)
  194. module = importlib.import_module(name)
  195. sys.path.pop(0)
  196. return module
  197. # used for print flops
  198. class FLOPs:
  199. @staticmethod
  200. def register_supported_ops():
  201. build = import_abspy("lib_mamba", os.path.join(os.path.dirname(os.path.abspath(__file__)), "../model"))
  202. selective_scan_flop_jit: Callable = build.vmamba.selective_scan_flop_jit
  203. flops_selective_scan_fn: Callable = build.csms6s.flops_selective_scan_fn
  204. flops_selective_scan_ref: Callable = build.csms6s.flops_selective_scan_ref
  205. supported_ops = {
  206. "aten::gelu": None, # as relu is in _IGNORED_OPS
  207. "aten::silu": None, # as relu is in _IGNORED_OPS
  208. "aten::neg": None, # as relu is in _IGNORED_OPS
  209. "aten::exp": None, # as relu is in _IGNORED_OPS
  210. "aten::flip": None, # as permute is in _IGNORED_OPS
  211. # "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, # latter
  212. # "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit, # latter
  213. # "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit, # latter
  214. # "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit, # latter
  215. # "prim::PythonOp.SelectiveScan": selective_scan_flop_jit, # latter
  216. # "prim::PythonOp.CrossScanTritonF": selective_scan_flop_jit, # latter
  217. "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=False),
  218. # "prim::PythonOp.CrossMergeTritonF": selective_scan_flop_jit, # latter
  219. # "aten::scaled_dot_product_attention": ...
  220. }
  221. return supported_ops
  222. @staticmethod
  223. def check_operations(model: torch.nn.Module, inputs=None, input_shape=(3, 224, 224)):
  224. from fvcore.nn.jit_analysis import _get_scoped_trace_graph, _named_modules_with_dup, Counter, JitModelAnalysis
  225. if inputs is None:
  226. assert input_shape is not None
  227. if len(input_shape) == 1:
  228. input_shape = (1, 3, input_shape[0], input_shape[0])
  229. elif len(input_shape) == 2:
  230. input_shape = (1, 3, *input_shape)
  231. elif len(input_shape) == 3:
  232. input_shape = (1, *input_shape)
  233. else:
  234. assert len(input_shape) == 4
  235. inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
  236. model.eval()
  237. flop_counter = JitModelAnalysis(model, inputs)
  238. flop_counter._ignored_ops = set()
  239. flop_counter._op_handles = dict()
  240. assert flop_counter.total() == 0 # make sure no operations supported
  241. print(flop_counter.unsupported_ops(), flush=True)
  242. print(f"supported ops {flop_counter._op_handles}; ignore ops {flop_counter._ignored_ops};", flush=True)
  243. @classmethod
  244. def fvcore_flop_count(cls, model: torch.nn.Module, inputs=None, input_shape=(3, 224, 224), show_table=True,
  245. show_arch=False, verbose=True):
  246. supported_ops = cls.register_supported_ops()
  247. from fvcore.nn.parameter_count import parameter_count as fvcore_parameter_count
  248. from fvcore.nn.flop_count import flop_count, FlopCountAnalysis, _DEFAULT_SUPPORTED_OPS
  249. from fvcore.nn.print_model_statistics import flop_count_str, flop_count_table
  250. from fvcore.nn.jit_analysis import _IGNORED_OPS
  251. from fvcore.nn.jit_handles import get_shape, addmm_flop_jit
  252. if inputs is None:
  253. assert input_shape is not None
  254. if len(input_shape) == 1:
  255. input_shape = (1, 3, input_shape[0], input_shape[0])
  256. elif len(input_shape) == 2:
  257. input_shape = (1, 3, *input_shape)
  258. elif len(input_shape) == 3:
  259. input_shape = (1, *input_shape)
  260. else:
  261. assert len(input_shape) == 4
  262. inputs = (torch.randn(input_shape).to(next(model.parameters()).device),)
  263. model.eval()
  264. Gflops, unsupported = flop_count(model=model, inputs=inputs, supported_ops=supported_ops)
  265. flops_table = flop_count_table(
  266. flops=FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
  267. max_depth=100,
  268. activations=None,
  269. show_param_shapes=True,
  270. )
  271. flops_str = flop_count_str(
  272. flops=FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
  273. activations=None,
  274. )
  275. if show_arch:
  276. print(flops_str)
  277. if show_table:
  278. print(flops_table)
  279. params = fvcore_parameter_count(model)[""]
  280. flops = sum(Gflops.values())
  281. if verbose:
  282. print(Gflops.items())
  283. print("[GFlops: {:>6.3f}G]" "[Params: {:>6.3f}M]".format(flops, params / 1e6), flush=True)
  284. return params, flops
  285. def get_val_dataloader(batch_size=64, root="./val", img_size=224, sequential=True):
  286. import torch.utils.data
  287. size = int((224 / 224) * img_size)
  288. transform = transforms.Compose([
  289. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  290. transforms.CenterCrop((img_size, img_size)),
  291. transforms.ToTensor(),
  292. transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
  293. ])
  294. dataset = datasets.ImageFolder(root, transform=transform)
  295. if sequential:
  296. sampler = torch.utils.data.SequentialSampler(dataset)
  297. else:
  298. sampler = torch.utils.data.DistributedSampler(dataset)
  299. data_loader = torch.utils.data.DataLoader(
  300. dataset, sampler=sampler,
  301. batch_size=batch_size,
  302. shuffle=False,
  303. num_workers=0,
  304. pin_memory=True,
  305. drop_last=False
  306. )
  307. return data_loader
  308. class Throughput:
  309. # default no amp in testing tp
  310. # copied from swin_transformer
  311. @staticmethod
  312. @torch.no_grad()
  313. def throughput(data_loader, model, logger=logging):
  314. model.eval()
  315. for idx, (images, _) in enumerate(data_loader):
  316. images = images.cuda(non_blocking=True)
  317. batch_size = images.shape[0]
  318. for i in range(50):
  319. model(images)
  320. torch.cuda.synchronize()
  321. logger.info(f"throughput averaged with 30 times")
  322. torch.cuda.reset_peak_memory_stats()
  323. tic1 = time.time()
  324. for i in range(30):
  325. model(images)
  326. torch.cuda.synchronize()
  327. tic2 = time.time()
  328. logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  329. logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
  330. return
  331. @staticmethod
  332. @torch.no_grad()
  333. def throughputamp(data_loader, model, logger=logging):
  334. model.eval()
  335. for idx, (images, _) in enumerate(data_loader):
  336. images = images.cuda(non_blocking=True)
  337. batch_size = images.shape[0]
  338. for i in range(50):
  339. with torch.cuda.amp.autocast():
  340. model(images)
  341. torch.cuda.synchronize()
  342. logger.info(f"throughput averaged with 30 times")
  343. torch.cuda.reset_peak_memory_stats()
  344. tic1 = time.time()
  345. for i in range(30):
  346. with torch.cuda.amp.autocast():
  347. model(images)
  348. torch.cuda.synchronize()
  349. tic2 = time.time()
  350. logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  351. logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
  352. return
  353. @staticmethod
  354. def testfwdbwd(data_loader, model, logger, amp=True):
  355. model.cuda().train()
  356. criterion = torch.nn.CrossEntropyLoss()
  357. for idx, (images, targets) in enumerate(data_loader):
  358. images = images.cuda(non_blocking=True)
  359. targets = targets.cuda(non_blocking=True)
  360. batch_size = images.shape[0]
  361. for i in range(50):
  362. with torch.cuda.amp.autocast(enabled=amp):
  363. out = model(images)
  364. loss = criterion(out, targets)
  365. loss.backward()
  366. torch.cuda.synchronize()
  367. logger.info(f"testfwdbwd averaged with 30 times")
  368. torch.cuda.reset_peak_memory_stats()
  369. tic1 = time.time()
  370. for i in range(30):
  371. with torch.cuda.amp.autocast(enabled=amp):
  372. out = model(images)
  373. loss = criterion(out, targets)
  374. loss.backward()
  375. torch.cuda.synchronize()
  376. tic2 = time.time()
  377. logger.info(f"batch_size {batch_size} testfwdbwd {30 * batch_size / (tic2 - tic1)}")
  378. logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
  379. return
  380. @classmethod
  381. def testall(cls, model, dataloader=None, data_path="", img_size=224, _batch_size=128, with_flops=False, inference_only=False):
  382. from fvcore.nn import parameter_count
  383. torch.cuda.empty_cache()
  384. model.cuda().eval()
  385. if with_flops:
  386. try:
  387. FLOPs.fvcore_flop_count(model, input_shape=(3, img_size, img_size), show_arch=False)
  388. except Exception as e:
  389. print("ERROR:", e, flush=True)
  390. # print(parameter_count(model)[""], sum(p.numel() for p in model.parameters() if p.requires_grad), flush=True)
  391. if dataloader is None:
  392. dataloader = get_val_dataloader(
  393. batch_size=_batch_size,
  394. root=os.path.join(os.path.abspath(data_path), "val"),
  395. img_size=img_size,
  396. )
  397. print('begin')
  398. cls.throughput(data_loader=dataloader, model=model, logger=logging)
  399. print("finished")
  400. if inference_only:
  401. return
  402. PASS = False
  403. batch_size = _batch_size
  404. while (not PASS) and (batch_size > 0):
  405. try:
  406. _dataloader = get_val_dataloader(
  407. batch_size=batch_size,
  408. root=os.path.join(os.path.abspath(data_path), "val"),
  409. img_size=img_size,
  410. )
  411. cls.testfwdbwd(data_loader=_dataloader, model=model, logger=logging)
  412. cls.testfwdbwd(data_loader=_dataloader, model=model, logger=logging, amp=False)
  413. PASS = True
  414. except Exception as e:
  415. print(e)
  416. batch_size = batch_size // 2
  417. print(f"batch_size {batch_size}", flush=True)
  418. # TIME_MIX_EXTRA_DIM = 32
  419. # TIME_DECAY_EXTRA_DIM = 64
  420. #
  421. # def vrwkv_flops(n, dim):
  422. # return n * dim * 29
  423. #
  424. # def vrwkv6_flops(n, dim, head_size):
  425. # addi_flops = 0
  426. # addi_flops += n * dim * (TIME_MIX_EXTRA_DIM * 10 + TIME_DECAY_EXTRA_DIM * 2 + 7 * head_size + 17)
  427. # addi_flops += n * (TIME_MIX_EXTRA_DIM * 5 + TIME_DECAY_EXTRA_DIM)
  428. # return addi_flops
  429. #
  430. # def get_addi_flops_vrwkv6(model, input_shape, cfg):
  431. # _, H, W = input_shape
  432. # try:
  433. # patch_size = cfg.model.backbone.patch_size
  434. # except:
  435. # patch_size = 16
  436. # h, w = H / patch_size, W / patch_size
  437. #
  438. # model_name = type(model.backbone).__name__
  439. # embed_dims = model.backbone.embed_dims
  440. # head_size = embed_dims // cfg.model.backbone.num_heads
  441. # print(f"Head Size in VRWKV6: {head_size}")
  442. # num_layers = len(model.backbone.layers)
  443. # addi_flops = 0
  444. # addi_flops += (num_layers * vrwkv6_flops(h * w, embed_dims, head_size))
  445. # print(f"Additional Flops in VRWKV6*{num_layers} layers: {flops_to_string(addi_flops)}")
  446. # return addi_flops
  447. #
  448. # def get_addi_flops_vrwkv(model, input_shape, cfg):
  449. # _, H, W = input_shape
  450. # try:
  451. # patch_size = cfg.model.backbone.patch_size
  452. # except:
  453. # patch_size = 16
  454. # h, w = H / patch_size, W / patch_size
  455. #
  456. # model_name = type(model.backbone).__name__
  457. # embed_dims = model.backbone.embed_dims
  458. # num_layers = len(model.backbone.layers)
  459. # addi_flops = 0
  460. # addi_flops += (num_layers * vrwkv_flops(h * w, embed_dims))
  461. # print(f"Additional Flops in VRWKV(Attn)*{num_layers} layers: {flops_to_string(addi_flops)}")
  462. # return addi_flops
  463. #
  464. # def get_flops(model, input_shape, cfg, ost):
  465. # flops, params = get_model_complexity_info(model, input_shape, as_strings=False, ost=ost)
  466. # model_name = type(model.backbone).__name__
  467. # if model_name == 'VRWKV':
  468. # add_flops = get_addi_flops_vrwkv(model, input_shape, cfg)
  469. # flops += add_flops
  470. # elif model_name == 'VRWKV6':
  471. # add_flops = get_addi_flops_vrwkv6(model, input_shape, cfg)
  472. # flops += add_flops
  473. # return flops_to_string(flops), params_to_string(params)
  474. # equals with fvcore_flop_count
  475. # @classmethod
  476. # def mmengine_flop_count(cls, model: nn.Module = None, input_shape=(3, 224, 224), show_table=False, show_arch=False,
  477. # _get_model_complexity_info=False):
  478. # supported_ops = cls.register_supported_ops()
  479. # from mmengine.analysis.print_helper import is_tuple_of, FlopAnalyzer, ActivationAnalyzer, parameter_count, \
  480. # _format_size, complexity_stats_table, complexity_stats_str
  481. # from mmengine.analysis.jit_analysis import _IGNORED_OPS
  482. # from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS, _DEFAULT_SUPPORTED_ACT_OPS
  483. # from mmengine.analysis import get_model_complexity_info as mm_get_model_complexity_info
  484. #
  485. # # modified from mmengine.analysis
  486. # def get_model_complexity_info(
  487. # model: nn.Module,
  488. # input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
  489. # None] = None,
  490. # inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
  491. # None] = None,
  492. # show_table: bool = True,
  493. # show_arch: bool = True,
  494. # ):
  495. # if input_shape is None and inputs is None:
  496. # raise ValueError('One of "input_shape" and "inputs" should be set.')
  497. # elif input_shape is not None and inputs is not None:
  498. # raise ValueError('"input_shape" and "inputs" cannot be both set.')
  499. #
  500. # if inputs is None:
  501. # device = next(model.parameters()).device
  502. # if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
  503. # inputs = (torch.randn(1, *input_shape).to(device),)
  504. # elif is_tuple_of(input_shape, tuple) and all([
  505. # is_tuple_of(one_input_shape, int)
  506. # for one_input_shape in input_shape # type: ignore
  507. # ]): # tuple of tuple of int, construct multiple tensors
  508. # inputs = tuple([
  509. # torch.randn(1, *one_input_shape).to(device)
  510. # for one_input_shape in input_shape # type: ignore
  511. # ])
  512. # else:
  513. # raise ValueError(
  514. # '"input_shape" should be either a `tuple of int` (to construct'
  515. # 'one input tensor) or a `tuple of tuple of int` (to construct'
  516. # 'multiple input tensors).')
  517. #
  518. # flop_handler = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops)
  519. # # activation_handler = ActivationAnalyzer(model, inputs)
  520. #
  521. # flops = flop_handler.total()
  522. # # activations = activation_handler.total()
  523. # params = parameter_count(model)['']
  524. #
  525. # flops_str = _format_size(flops)
  526. # # activations_str = _format_size(activations)
  527. # params_str = _format_size(params)
  528. #
  529. # if show_table:
  530. # complexity_table = complexity_stats_table(
  531. # flops=flop_handler,
  532. # # activations=activation_handler,
  533. # show_param_shapes=True,
  534. # )
  535. # complexity_table = '\n' + complexity_table
  536. # else:
  537. # complexity_table = ''
  538. #
  539. # if show_arch:
  540. # complexity_arch = complexity_stats_str(
  541. # flops=flop_handler,
  542. # # activations=activation_handler,
  543. # )
  544. # complexity_arch = '\n' + complexity_arch
  545. # else:
  546. # complexity_arch = ''
  547. #
  548. # return {
  549. # 'flops': flops,
  550. # 'flops_str': flops_str,
  551. # # 'activations': activations,
  552. # # 'activations_str': activations_str,
  553. # 'params': params,
  554. # 'params_str': params_str,
  555. # 'out_table': complexity_table,
  556. # 'out_arch': complexity_arch
  557. # }
  558. #
  559. # if _get_model_complexity_info:
  560. # return get_model_complexity_info
  561. #
  562. # model.eval()
  563. # analysis_results = get_model_complexity_info(
  564. # model,
  565. # input_shape,
  566. # show_table=show_table,
  567. # show_arch=show_arch,
  568. # )
  569. # flops = analysis_results['flops_str']
  570. # params = analysis_results['params_str']
  571. # # activations = analysis_results['activations_str']
  572. # out_table = analysis_results['out_table']
  573. # out_arch = analysis_results['out_arch']
  574. #
  575. # if show_arch:
  576. # print(out_arch)
  577. #
  578. # if show_table:
  579. # print(out_table)
  580. #
  581. # split_line = '=' * 30
  582. # print(f'{split_line}\nInput shape: {input_shape}\t'
  583. # f'Flops: {flops}\tParams: {params}\t'
  584. # # f'Activation: {activations}\n{split_line}'
  585. # , flush=True)
  586. #
  587. # # print('!!!Only the backbone network is counted in FLOPs analysis.')
  588. # # print('!!!Please be cautious if you use the results in papers. '
  589. # # 'You may need to check if all ops are supported and verify that the '
  590. # # 'flops computation is correct.')
  591. #
  592. # @classmethod
  593. # def mmdet_flops(cls, config=None, extra_config=None):
  594. # from mmengine.config import Config
  595. # from mmengine.runner import Runner
  596. # import numpy as np
  597. # import os
  598. #
  599. # cfg = Config.fromfile(config)
  600. # if "model" in cfg:
  601. # if "pretrained" in cfg["model"]:
  602. # cfg["model"].pop("pretrained")
  603. # if extra_config is not None:
  604. # new_cfg = Config.fromfile(extra_config)
  605. # new_cfg["model"] = cfg["model"]
  606. # cfg = new_cfg
  607. # cfg["work_dir"] = "/tmp"
  608. # cfg["default_scope"] = "mmdet"
  609. # runner = Runner.from_cfg(cfg)
  610. # model = runner.model.cuda()
  611. # get_model_complexity_info = cls.mmengine_flop_count(_get_model_complexity_info=True)
  612. #
  613. # if True:
  614. # oridir = os.getcwd()
  615. # os.chdir(os.path.join(os.path.dirname(__file__), "../detection"))
  616. # data_loader = runner.val_dataloader
  617. # num_images = 100
  618. # mean_flops = []
  619. # for idx, data_batch in enumerate(data_loader):
  620. # if idx == num_images:
  621. # break
  622. # data = model.data_preprocessor(data_batch)
  623. # model.forward = partial(model.forward, data_samples=data['data_samples'])
  624. # # out = get_model_complexity_info(model, inputs=data['inputs'])
  625. # out = get_model_complexity_info(model, input_shape=(3, 1280, 800))
  626. # params = out['params_str']
  627. # mean_flops.append(out['flops'])
  628. # mean_flops = np.average(np.array(mean_flops))
  629. # print(params, mean_flops)
  630. # os.chdir(oridir)
  631. #
  632. # @classmethod
  633. # def mmseg_flops(cls, config=None, input_shape=(3, 512, 2048)):
  634. # from mmengine.config import Config
  635. # from mmengine.runner import Runner
  636. #
  637. # cfg = Config.fromfile(config)
  638. # cfg["work_dir"] = "/tmp"
  639. # cfg["default_scope"] = "mmseg"
  640. # runner = Runner.from_cfg(cfg)
  641. # model = runner.model.cuda()
  642. #
  643. # cls.fvcore_flop_count(model, input_shape=input_shape)