| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713 |
- import copy
- import os
- import sys
- import time
- import logging
- import shutil
- import argparse
- import torch
- from tensorboardX import SummaryWriter
- from typing import Callable
- from functools import partial
- from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from torchvision import datasets, transforms
- # from mmcv import Config
- # from mmcv.cnn.utils import get_model_complexity_info
- # from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
def str2bool(v):
    """Convert a command-line style truthy/falsy token to a ``bool``.

    Suitable as an ``argparse`` ``type=`` callable.

    Args:
        v: One of ``yes/true/t/y/1`` or ``no/false/f/n/0`` (matched
            case-insensitively, surrounding whitespace ignored). A real
            ``bool`` is returned unchanged.

    Returns:
        The corresponding boolean value.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised token.
    """
    # A value that is already boolean needs no parsing.
    if isinstance(v, bool):
        return v
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def run_pre(cfg):
    """Delay start-up for a fixed time or until the first GPU looks idle.

    Args:
        cfg: Config object exposing integer attributes ``sleep`` and
            ``memory``. When ``cfg.sleep > -1`` the function counts down
            ``cfg.sleep`` seconds. Otherwise, when ``cfg.memory > -1``, it
            polls ``nvidia-smi`` once per second until the first reported
            "Used" value drops below 3000. If neither condition holds the
            function returns immediately.
    """
    # Time-based wait.
    if cfg.sleep > -1:
        for i in range(cfg.sleep):
            time.sleep(1)
            print('\rCount down : {} s'.format(cfg.sleep - 1 - i), end='')
    # Memory-based wait.
    elif cfg.memory > -1:
        # NOTE(review): cfg.memory only enables this branch; the threshold
        # itself is the hard-coded 3000 below (presumably MiB) -- confirm
        # whether cfg.memory was meant to be the threshold.
        s_times = 0
        while True:
            # The shell pipeline writes the "Used" lines to ./tmp.
            os.system('nvidia-smi -q -d Memory | grep -A4 GPU | grep Used > tmp')
            # Context manager so the handle is closed on every poll.
            with open('tmp', 'r') as report:
                memory_used = [int(line.split()[2]) for line in report]
            if memory_used[0] < 3000:
                os.remove('tmp')
                break
            s_times += 1
            time.sleep(1)
            print('\rWaiting for {} s'.format(s_times), end='')
def makedirs(dirs, exist_ok=False):
    """Create one or several directories, including missing parents.

    Args:
        dirs: A single path, or a list/tuple of paths.
        exist_ok: Forwarded to ``os.makedirs``; when ``False`` an existing
            directory raises ``FileExistsError``.
    """
    # Tuples are treated like lists; anything else is a single path.
    if not isinstance(dirs, (list, tuple)):
        dirs = [dirs]
    for directory in dirs:
        os.makedirs(directory, exist_ok=exist_ok)
-
-
def init_checkpoint(cfg):
    """Prepare the run directory, training counters, logger and writer.

    When ``cfg.trainer.resume_dir`` is set, the run directory is
    ``<cfg.trainer.checkpoint>/<resume_dir>``, the checkpoint path in
    ``cfg.model.model_kwargs`` is rewritten to live inside it (defaulting to
    ``latest_ckpt.pth``), and ``iter``, ``epoch`` and ``topk_recorder`` are
    restored from that checkpoint. Otherwise the master process creates a
    new timestamped run directory and copies the config file into it, and
    the counters start from zero.

    Args:
        cfg: Config object. Reads ``trainer``, ``model``, ``data``,
            ``master``, ``cfg_path`` and ``logger_rank``; sets ``logdir``,
            ``logger``, ``writer`` and ``trainer.iter`` / ``trainer.epoch``
            / ``trainer.topk_recorder``.
    """

    def rm_zero_size_file(path):
        """Delete regular files smaller than 8 bytes directly in ``path``."""
        for file in os.listdir(path):
            file_path = '{}/{}'.format(path, file)
            # Check isfile first so directories are never sized or removed.
            if os.path.isfile(file_path) and os.path.getsize(file_path) < 8:  # unit: B
                os.remove(file_path)

    os.makedirs(cfg.trainer.checkpoint, exist_ok=True)
    if cfg.trainer.resume_dir:
        cfg.logdir = '{}/{}'.format(cfg.trainer.checkpoint, cfg.trainer.resume_dir)
        checkpoint_path = cfg.model.model_kwargs['checkpoint_path']
        if checkpoint_path == '':
            cfg.model.model_kwargs['checkpoint_path'] = '{}/latest_ckpt.pth'.format(cfg.logdir)
        else:
            # Only the file name is kept; the checkpoint must sit in logdir.
            cfg.model.model_kwargs['checkpoint_path'] = '{}/{}'.format(cfg.logdir, checkpoint_path.split('/')[-1])
        # NOTE: torch.load unpickles the file -- resume only from trusted checkpoints.
        state_dict = torch.load(cfg.model.model_kwargs['checkpoint_path'], map_location='cpu')
        cfg.trainer.iter, cfg.trainer.epoch = state_dict['iter'], state_dict['epoch']
        cfg.trainer.topk_recorder = state_dict['topk_recorder']
    else:
        if cfg.master:
            logdir = '{}_{}_{}_{}'.format(cfg.trainer.name, cfg.model.name, cfg.data.type, time.strftime("%Y%m%d-%H%M%S"))
            cfg.logdir = '{}/{}'.format(cfg.trainer.checkpoint, logdir)
            os.makedirs(cfg.logdir, exist_ok=True)
            # cfg_path is a dotted module path; copy its source file into the run dir.
            shutil.copy('{}.py'.format('/'.join(cfg.cfg_path.split('.'))), '{}/{}.py'.format(cfg.logdir, cfg.cfg_path.split('.')[-1]))
        else:
            cfg.logdir = None
        cfg.trainer.iter, cfg.trainer.epoch = 0, 0
        cfg.trainer.topk_recorder = dict(net_top1=[], net_top5=[], net_E_top1=[], net_E_top5=[])
    # Only the master process logs and writes TensorBoard events.
    cfg.logger = get_logger(cfg) if cfg.master else None
    cfg.writer = SummaryWriter(log_dir=cfg.logdir, comment='') if cfg.master else None
    log_msg(cfg.logger, f'==> Logging on master GPU: {cfg.logger_rank}')
    # rm_zero_size_file(cfg.logdir) if cfg.master else None
def log_cfg(cfg):
    """Flatten ``cfg`` into dotted keys, store the result on it and log it.

    Every object that has a ``__dict__`` is descended into (attributes whose
    name starts with ``_`` are skipped); every other value becomes a leaf
    keyed by its dotted attribute path.

    Side effects:
        Sets ``cfg.cfg_dict`` (all flattened entries) and ``cfg.cfg_str``
        (aligned ``key : value`` lines, omitting keys that contain
        ``'writer.'`` or ``'logger.handlers'``), then logs ``cfg.cfg_str``
        through ``cfg.logger``.

    Args:
        cfg: Namespace-like config object with a ``logger`` attribute.
    """

    def _parse_Namespace(node, base_str=''):
        ret = {}
        if hasattr(node, '__dict__'):
            for key, val in node.__dict__.items():
                if not key.startswith('_'):
                    # lstrip removes the leading dot produced at the root level.
                    ret.update(_parse_Namespace(val, '{}.{}'.format(base_str, key).lstrip('.')))
        else:
            ret[base_str] = node
        return ret

    cfg_dict = _parse_Namespace(cfg)
    # default=0 keeps an empty config from raising ValueError.
    key_max_length = max(map(len, cfg_dict), default=0)
    excludes = ('writer.', 'logger.handlers')
    # Both columns are left-aligned to the longest key.
    line_fmt = '{{:<{0}}} : {{:<{0}}}'.format(key_max_length)
    lines = [
        line_fmt.format(k, str(v))
        for k, v in cfg_dict.items()
        if not any(exclude in k for exclude in excludes)
    ]
    cfg_str = '\n'.join(lines).strip()
    cfg.cfg_dict, cfg.cfg_str = cfg_dict, cfg_str
    log_msg(cfg.logger, f'==> ********** cfg ********** \n{cfg.cfg_str}')
def get_logger(cfg, mode='a+'):
    """Set up root logging to stdout plus a log file in the run directory.

    Args:
        cfg: Config object providing ``logdir`` and ``mode``; the log file is
            ``<logdir>/log_<mode>.txt``. ``cfg.logger`` is set to the logger.
        mode: File open mode passed to ``logging.FileHandler``.

    Returns:
        The root logger with the file handler attached.
    """
    record_layout = '%(asctime)s - %(message)s'
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format=record_layout,
        datefmt='%m/%d %I:%M:%S %p',
    )

    log_file = '{}/log_{}.txt'.format(cfg.logdir, cfg.mode)
    file_handler = logging.FileHandler(log_file, mode=mode)
    file_handler.setFormatter(logging.Formatter(record_layout))

    root_logger = logging.getLogger()
    root_logger.addHandler(file_handler)
    cfg.logger = root_logger
    return root_logger
def able(ret, mark=False, default=None):
    """Return ``ret`` when ``mark`` is truthy, otherwise ``default``."""
    if mark:
        return ret
    return default
def log_msg(logger, msg, level='info'):
    """Emit ``msg`` via ``logger.info`` if there is something to log.

    Nothing happens when ``logger`` is ``None``, when ``msg`` is ``None``,
    or when ``level`` is anything other than ``'info'``.
    """
    if logger is None or msg is None or level != 'info':
        return
    logger.info(msg)
class AvgMeter(object):
    """Track the latest value, weighted sum, count and mean of a metric.

    Args:
        name: Label printed by ``__str__``.
        fmt: Format spec (including the leading ``:``) applied to values.
        show_name: Name of the attribute shown first (e.g. ``'val'``).
        add_name: Optional second attribute shown in parentheses; empty to
            omit it.
    """

    def __init__(self, name, fmt=':f', show_name='val', add_name=''):
        self.name = name
        self.fmt = fmt
        self.show_name = show_name
        self.add_name = add_name
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against n == 0 on a fresh meter (would divide by zero).
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Builds e.g. '[{name} {val:f} ({avg:f})]' and fills it from the
        # instance attributes.
        fmtstr = '[{name} {' + self.show_name + self.fmt + '}'
        fmtstr += (' ({' + self.add_name + self.fmt + '})]' if self.add_name else ']')
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Compose a one-line progress message from a collection of meters.

    Args:
        meters: Mapping whose values expose ``count`` and ``__str__``.
        default_prefix: Prefix used when ``get_msg`` is given none.
    """

    def __init__(self, meters, default_prefix=""):
        self.meters = meters
        self.default_prefix = default_prefix
        self.iter_fmtstr_iter = '{}: {:>3.2f}% [{}/{}]'
        self.iter_fmtstr_batch = ' [{:<.1f}/{:<3.1f}]'

    def get_msg(self, iter, iter_full, epoch=None, epoch_full=None, prefix=None):
        """Return the progress line.

        The line holds the prefix, the percentage ``iter / iter_full``, the
        raw iteration counters, an epoch section when ``epoch`` is truthy,
        and the string form of every meter updated at least once.
        """
        label = prefix or self.default_prefix
        percent = iter / iter_full * 100
        parts = [self.iter_fmtstr_iter.format(label, percent, iter, iter_full)]
        if epoch:
            parts.append(self.iter_fmtstr_batch.format(epoch, epoch_full))
        parts.extend(str(meter) for meter in self.meters.values() if meter.count > 0)
        return ' '.join(parts)
def get_log_terms(log_terms, default_prefix=''):
    """Create one ``AvgMeter`` per spec plus a ``ProgressMeter`` over them.

    Args:
        log_terms: Iterable of dicts holding ``AvgMeter`` keyword arguments;
            each must contain ``'name'``, which becomes the meter's key.
        default_prefix: Forwarded to ``ProgressMeter``.

    Returns:
        Tuple ``(meters_by_name, progress_meter)``.
    """
    # Each spec is copied before unpacking, so the caller's dict is untouched.
    meters = {spec['name']: AvgMeter(**dict(spec.items())) for spec in log_terms}
    return meters, ProgressMeter(meters, default_prefix=default_prefix)
def update_log_term(term, val, n, master):
    """Forward ``(val, n)`` to ``term.update`` when both ``term`` and
    ``master`` are truthy; otherwise do nothing."""
    if not (term and master):
        return
    term.update(val, n)
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each ``k`` in ``topk``.

    Args:
        output: 2-D score tensor; the top-k search runs along dimension 1.
        target: Tensor of labels whose first dimension is the batch size.
        topk: Iterable of ``k`` values.

    Returns:
        Tuple ``(percentages, counts)``: ``percentages[i]`` is the top-k
        accuracy in percent for ``topk[i]``; ``counts`` holds the matching
        numbers of correct predictions followed by the batch size.
    """
    batch_size = target.size(0)
    _, top_idx = output.topk(max(topk), 1, True, True)
    top_idx = top_idx.t()
    # Row r marks the samples whose (r + 1)-th best guess equals the label.
    hits = top_idx.eq(target.reshape(1, -1).expand_as(top_idx))

    percentages, counts = [], []
    for k in topk:
        n_correct = hits[:k].reshape(-1).float().sum(0)
        percentages.append(n_correct * 100. / batch_size)
        counts.append(n_correct)
    counts.append(batch_size)
    return percentages, counts
def get_timepc():
    """Return ``time.perf_counter()``, synchronising CUDA first if available
    so pending GPU work is included in the measurement."""
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.synchronize()
    return time.perf_counter()
def get_net_params(net):
    """Return the number of trainable parameters of ``net`` in millions."""
    trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return trainable / 1e6
def import_abspy(name="models", path="classification/"):
    """Import module ``name`` from directory ``path``.

    ``path`` is temporarily placed at the front of ``sys.path`` for the
    duration of the import and removed again afterwards, even when the
    import raises.

    Args:
        name: Module name passed to ``importlib.import_module``.
        path: Directory containing the module; made absolute first.

    Returns:
        The imported module object.

    Raises:
        AssertionError: If ``path`` is not an existing directory.
    """
    import sys
    import importlib
    path = os.path.abspath(path)
    assert os.path.isdir(path), 'not a directory: {}'.format(path)
    sys.path.insert(0, path)
    try:
        return importlib.import_module(name)
    finally:
        # Remove by value rather than pop(0): the imported module may itself
        # have modified sys.path.
        try:
            sys.path.remove(path)
        except ValueError:
            pass
- # used for print flops
class FLOPs:
    """Helpers for counting FLOPs and parameters with fvcore."""

    @staticmethod
    def register_supported_ops():
        """Build the op-handle table handed to fvcore's flop counters.

        Loads ``lib_mamba`` from ``../model`` (relative to this file) to get
        the selective-scan FLOP handler.

        Returns:
            Dict mapping JIT op names to a handler callable, or ``None`` for
            ops that are registered without a handler.
        """
        build = import_abspy("lib_mamba", os.path.join(os.path.dirname(os.path.abspath(__file__)), "../model"))
        selective_scan_flop_jit: Callable = build.vmamba.selective_scan_flop_jit
        supported_ops = {
            "aten::gelu": None,  # as relu is in _IGNORED_OPS
            "aten::silu": None,  # as relu is in _IGNORED_OPS
            "aten::neg": None,  # as relu is in _IGNORED_OPS
            "aten::exp": None,  # as relu is in _IGNORED_OPS
            "aten::flip": None,  # as permute is in _IGNORED_OPS
            # "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, # latter
            # "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit, # latter
            # "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit, # latter
            # "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit, # latter
            # "prim::PythonOp.SelectiveScan": selective_scan_flop_jit, # latter
            # "prim::PythonOp.CrossScanTritonF": selective_scan_flop_jit, # latter
            "prim::PythonOp.SelectiveScanCuda": partial(selective_scan_flop_jit, backend="prefixsum", verbose=False),
            # "prim::PythonOp.CrossMergeTritonF": selective_scan_flop_jit, # latter
            # "aten::scaled_dot_product_attention": ...
        }
        return supported_ops

    @staticmethod
    def _make_inputs(model, inputs, input_shape):
        """Return ``inputs``, or a random input tuple built from ``input_shape``.

        ``input_shape`` may have 1 to 4 entries: ``(S,)`` -> ``(1, 3, S, S)``,
        ``(H, W)`` -> ``(1, 3, H, W)``, ``(C, H, W)`` -> ``(1, C, H, W)``; a
        4-entry shape is used as is. The tensor is placed on the device of
        the model's first parameter (CPU when the model has none).
        """
        if inputs is not None:
            return inputs
        assert input_shape is not None
        if len(input_shape) == 1:
            input_shape = (1, 3, input_shape[0], input_shape[0])
        elif len(input_shape) == 2:
            input_shape = (1, 3, *input_shape)
        elif len(input_shape) == 3:
            input_shape = (1, *input_shape)
        else:
            assert len(input_shape) == 4
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        return (torch.randn(input_shape).to(device),)

    @staticmethod
    def check_operations(model: torch.nn.Module, inputs=None, input_shape=(3, 224, 224)):
        """Print every op fvcore sees in ``model`` when no handler is registered.

        Args:
            model: Module to trace; switched to eval mode.
            inputs: Optional tuple of input tensors; built from
                ``input_shape`` when ``None``.
            input_shape: Shape used to create a random input (1-4 entries).
        """
        from fvcore.nn.jit_analysis import JitModelAnalysis
        inputs = FLOPs._make_inputs(model, inputs, input_shape)
        model.eval()
        flop_counter = JitModelAnalysis(model, inputs)
        # Clear fvcore's defaults so that every op is reported as unsupported.
        flop_counter._ignored_ops = set()
        flop_counter._op_handles = dict()
        assert flop_counter.total() == 0  # make sure no operations supported
        print(flop_counter.unsupported_ops(), flush=True)
        print(f"supported ops {flop_counter._op_handles}; ignore ops {flop_counter._ignored_ops};", flush=True)

    @classmethod
    def fvcore_flop_count(cls, model: torch.nn.Module, inputs=None, input_shape=(3, 224, 224), show_table=True,
                          show_arch=False, verbose=True):
        """Count parameters and FLOPs of ``model`` with fvcore.

        Args:
            model: Module to analyse; switched to eval mode.
            inputs: Optional tuple of input tensors; built from
                ``input_shape`` when ``None``.
            input_shape: Shape used to create a random input (1-4 entries).
            show_table: Print fvcore's per-module FLOP table.
            show_arch: Print fvcore's annotated architecture string.
            verbose: Print the per-op counts and a one-line summary.

        Returns:
            Tuple ``(params, flops)``: the total parameter count and the sum
            of the per-op values returned by ``flop_count`` (printed as
            GFLOPs in the summary line).
        """
        supported_ops = cls.register_supported_ops()
        from fvcore.nn.parameter_count import parameter_count as fvcore_parameter_count
        from fvcore.nn.flop_count import flop_count, FlopCountAnalysis
        from fvcore.nn.print_model_statistics import flop_count_str, flop_count_table
        inputs = cls._make_inputs(model, inputs, input_shape)
        model.eval()
        Gflops, _unsupported = flop_count(model=model, inputs=inputs, supported_ops=supported_ops)
        # Each report re-traces the model, so build it only when it is printed.
        if show_arch:
            print(flop_count_str(
                flops=FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
                activations=None,
            ))
        if show_table:
            print(flop_count_table(
                flops=FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops),
                max_depth=100,
                activations=None,
                show_param_shapes=True,
            ))
        params = fvcore_parameter_count(model)[""]
        flops = sum(Gflops.values())
        if verbose:
            print(Gflops.items())
            print("[GFlops: {:>6.3f}G]" "[Params: {:>6.3f}M]".format(flops, params / 1e6), flush=True)
        return params, flops
def get_val_dataloader(batch_size=64, root="./val", img_size=224, sequential=True):
    """Build a validation ``DataLoader`` over an ``ImageFolder`` at ``root``.

    Images are resized (bicubic), centre-cropped to ``img_size`` x
    ``img_size``, converted to tensors and normalised with the ImageNet
    default mean/std.

    Args:
        batch_size: Samples per batch.
        root: Directory in ``ImageFolder`` layout.
        img_size: Side length of the square crop.
        sequential: Use ``SequentialSampler`` when true, otherwise
            ``DistributedSampler``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler

    # Resize/crop ratio is 224 / 224, i.e. the resize target equals img_size.
    resize_to = int((224 / 224) * img_size)
    preprocess = transforms.Compose([
        transforms.Resize(resize_to, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    dataset = datasets.ImageFolder(root, transform=preprocess)
    sampler_cls = SequentialSampler if sequential else DistributedSampler
    return DataLoader(
        dataset,
        sampler=sampler_cls(dataset),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
class Throughput:
    """CUDA throughput and peak-memory benchmarks.

    Throughput testing uses no AMP by default; the timing loop is copied
    from swin_transformer. Each benchmark uses only the first batch of the
    data loader: 50 warm-up iterations followed by 30 timed iterations.
    """

    @staticmethod
    @torch.no_grad()
    def _timed_forward(data_loader, model, logger, amp):
        """Time forward passes on the first batch, optionally under autocast."""
        from contextlib import nullcontext

        def precision_ctx():
            # Without AMP the model runs in whatever context the caller set.
            return torch.cuda.amp.autocast() if amp else nullcontext()

        model.eval()
        for images, _targets in data_loader:
            images = images.cuda(non_blocking=True)
            batch_size = images.shape[0]
            for _ in range(50):  # warm-up
                with precision_ctx():
                    model(images)
            torch.cuda.synchronize()
            logger.info("throughput averaged with 30 times")
            torch.cuda.reset_peak_memory_stats()
            tic1 = time.time()
            for _ in range(30):
                with precision_ctx():
                    model(images)
            torch.cuda.synchronize()
            tic2 = time.time()
            logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
            logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
            return

    @staticmethod
    @torch.no_grad()
    def throughput(data_loader, model, logger=logging):
        """Log inference throughput (images/s) and peak memory without AMP.

        Args:
            data_loader: Iterable yielding ``(images, targets)`` batches.
            model: Module already on the GPU; switched to eval mode.
            logger: Object with an ``info`` method (defaults to ``logging``).
        """
        return Throughput._timed_forward(data_loader, model, logger, amp=False)

    @staticmethod
    @torch.no_grad()
    def throughputamp(data_loader, model, logger=logging):
        """Same as ``throughput`` but every forward pass runs under
        ``torch.cuda.amp.autocast()``."""
        return Throughput._timed_forward(data_loader, model, logger, amp=True)

    @staticmethod
    def testfwdbwd(data_loader, model, logger, amp=True):
        """Log forward+backward throughput and peak memory on the first batch.

        The model is moved to the GPU and put in train mode; the loss is
        cross-entropy against the batch targets. Gradients are not zeroed
        between iterations and no optimizer step is taken.

        Args:
            data_loader: Iterable yielding ``(images, targets)`` batches.
            model: Module to benchmark.
            logger: Object with an ``info`` method.
            amp: Whether autocast is enabled for forward and loss.
        """
        model.cuda().train()
        criterion = torch.nn.CrossEntropyLoss()
        for images, targets in data_loader:
            images = images.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            batch_size = images.shape[0]
            for _ in range(50):  # warm-up
                with torch.cuda.amp.autocast(enabled=amp):
                    out = model(images)
                    loss = criterion(out, targets)
                loss.backward()
            torch.cuda.synchronize()
            logger.info("testfwdbwd averaged with 30 times")
            torch.cuda.reset_peak_memory_stats()
            tic1 = time.time()
            for _ in range(30):
                with torch.cuda.amp.autocast(enabled=amp):
                    out = model(images)
                    loss = criterion(out, targets)
                loss.backward()
            torch.cuda.synchronize()
            tic2 = time.time()
            logger.info(f"batch_size {batch_size} testfwdbwd {30 * batch_size / (tic2 - tic1)}")
            logger.info(f"batch_size {batch_size} mem cost {torch.cuda.max_memory_allocated() / 1024 / 1024} MB")
            return

    @classmethod
    def testall(cls, model, dataloader=None, data_path="", img_size=224, _batch_size=128, with_flops=False, inference_only=False):
        """Run the FLOP count (optional), inference and training benchmarks.

        Args:
            model: Module to benchmark; moved to the GPU.
            dataloader: Loader for the inference benchmark; when ``None`` one
                is built from ``<data_path>/val``.
            data_path: Dataset root containing a ``val`` sub-directory.
            img_size: Input resolution.
            _batch_size: Initial batch size.
            with_flops: Also print FLOPs/params via ``FLOPs.fvcore_flop_count``.
            inference_only: Skip the forward+backward benchmark.

        The forward+backward benchmark always builds its own loader from
        ``<data_path>/val`` and halves the batch size after any exception
        until a run succeeds or the batch size reaches zero.
        """
        torch.cuda.empty_cache()
        model.cuda().eval()
        if with_flops:
            # Best effort: a failing FLOP count must not stop the benchmarks.
            try:
                FLOPs.fvcore_flop_count(model, input_shape=(3, img_size, img_size), show_arch=False)
            except Exception as e:
                print("ERROR:", e, flush=True)
        if dataloader is None:
            dataloader = get_val_dataloader(
                batch_size=_batch_size,
                root=os.path.join(os.path.abspath(data_path), "val"),
                img_size=img_size,
            )
        print('begin')
        cls.throughput(data_loader=dataloader, model=model, logger=logging)
        print("finished")
        if inference_only:
            return
        PASS = False
        batch_size = _batch_size
        while (not PASS) and (batch_size > 0):
            try:
                _dataloader = get_val_dataloader(
                    batch_size=batch_size,
                    root=os.path.join(os.path.abspath(data_path), "val"),
                    img_size=img_size,
                )
                cls.testfwdbwd(data_loader=_dataloader, model=model, logger=logging)
                cls.testfwdbwd(data_loader=_dataloader, model=model, logger=logging, amp=False)
                PASS = True
            except Exception as e:
                # Typically out-of-memory: retry with half the batch size.
                print(e)
                batch_size = batch_size // 2
                print(f"batch_size {batch_size}", flush=True)
- # TIME_MIX_EXTRA_DIM = 32
- # TIME_DECAY_EXTRA_DIM = 64
- #
- # def vrwkv_flops(n, dim):
- # return n * dim * 29
- #
- # def vrwkv6_flops(n, dim, head_size):
- # addi_flops = 0
- # addi_flops += n * dim * (TIME_MIX_EXTRA_DIM * 10 + TIME_DECAY_EXTRA_DIM * 2 + 7 * head_size + 17)
- # addi_flops += n * (TIME_MIX_EXTRA_DIM * 5 + TIME_DECAY_EXTRA_DIM)
- # return addi_flops
- #
- # def get_addi_flops_vrwkv6(model, input_shape, cfg):
- # _, H, W = input_shape
- # try:
- # patch_size = cfg.model.backbone.patch_size
- # except:
- # patch_size = 16
- # h, w = H / patch_size, W / patch_size
- #
- # model_name = type(model.backbone).__name__
- # embed_dims = model.backbone.embed_dims
- # head_size = embed_dims // cfg.model.backbone.num_heads
- # print(f"Head Size in VRWKV6: {head_size}")
- # num_layers = len(model.backbone.layers)
- # addi_flops = 0
- # addi_flops += (num_layers * vrwkv6_flops(h * w, embed_dims, head_size))
- # print(f"Additional Flops in VRWKV6*{num_layers} layers: {flops_to_string(addi_flops)}")
- # return addi_flops
- #
- # def get_addi_flops_vrwkv(model, input_shape, cfg):
- # _, H, W = input_shape
- # try:
- # patch_size = cfg.model.backbone.patch_size
- # except:
- # patch_size = 16
- # h, w = H / patch_size, W / patch_size
- #
- # model_name = type(model.backbone).__name__
- # embed_dims = model.backbone.embed_dims
- # num_layers = len(model.backbone.layers)
- # addi_flops = 0
- # addi_flops += (num_layers * vrwkv_flops(h * w, embed_dims))
- # print(f"Additional Flops in VRWKV(Attn)*{num_layers} layers: {flops_to_string(addi_flops)}")
- # return addi_flops
- #
- # def get_flops(model, input_shape, cfg, ost):
- # flops, params = get_model_complexity_info(model, input_shape, as_strings=False, ost=ost)
- # model_name = type(model.backbone).__name__
- # if model_name == 'VRWKV':
- # add_flops = get_addi_flops_vrwkv(model, input_shape, cfg)
- # flops += add_flops
- # elif model_name == 'VRWKV6':
- # add_flops = get_addi_flops_vrwkv6(model, input_shape, cfg)
- # flops += add_flops
- # return flops_to_string(flops), params_to_string(params)
- # equals with fvcore_flop_count
- # @classmethod
- # def mmengine_flop_count(cls, model: nn.Module = None, input_shape=(3, 224, 224), show_table=False, show_arch=False,
- # _get_model_complexity_info=False):
- # supported_ops = cls.register_supported_ops()
- # from mmengine.analysis.print_helper import is_tuple_of, FlopAnalyzer, ActivationAnalyzer, parameter_count, \
- # _format_size, complexity_stats_table, complexity_stats_str
- # from mmengine.analysis.jit_analysis import _IGNORED_OPS
- # from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS, _DEFAULT_SUPPORTED_ACT_OPS
- # from mmengine.analysis import get_model_complexity_info as mm_get_model_complexity_info
- #
- # # modified from mmengine.analysis
- # def get_model_complexity_info(
- # model: nn.Module,
- # input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
- # None] = None,
- # inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
- # None] = None,
- # show_table: bool = True,
- # show_arch: bool = True,
- # ):
- # if input_shape is None and inputs is None:
- # raise ValueError('One of "input_shape" and "inputs" should be set.')
- # elif input_shape is not None and inputs is not None:
- # raise ValueError('"input_shape" and "inputs" cannot be both set.')
- #
- # if inputs is None:
- # device = next(model.parameters()).device
- # if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
- # inputs = (torch.randn(1, *input_shape).to(device),)
- # elif is_tuple_of(input_shape, tuple) and all([
- # is_tuple_of(one_input_shape, int)
- # for one_input_shape in input_shape # type: ignore
- # ]): # tuple of tuple of int, construct multiple tensors
- # inputs = tuple([
- # torch.randn(1, *one_input_shape).to(device)
- # for one_input_shape in input_shape # type: ignore
- # ])
- # else:
- # raise ValueError(
- # '"input_shape" should be either a `tuple of int` (to construct'
- # 'one input tensor) or a `tuple of tuple of int` (to construct'
- # 'multiple input tensors).')
- #
- # flop_handler = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops)
- # # activation_handler = ActivationAnalyzer(model, inputs)
- #
- # flops = flop_handler.total()
- # # activations = activation_handler.total()
- # params = parameter_count(model)['']
- #
- # flops_str = _format_size(flops)
- # # activations_str = _format_size(activations)
- # params_str = _format_size(params)
- #
- # if show_table:
- # complexity_table = complexity_stats_table(
- # flops=flop_handler,
- # # activations=activation_handler,
- # show_param_shapes=True,
- # )
- # complexity_table = '\n' + complexity_table
- # else:
- # complexity_table = ''
- #
- # if show_arch:
- # complexity_arch = complexity_stats_str(
- # flops=flop_handler,
- # # activations=activation_handler,
- # )
- # complexity_arch = '\n' + complexity_arch
- # else:
- # complexity_arch = ''
- #
- # return {
- # 'flops': flops,
- # 'flops_str': flops_str,
- # # 'activations': activations,
- # # 'activations_str': activations_str,
- # 'params': params,
- # 'params_str': params_str,
- # 'out_table': complexity_table,
- # 'out_arch': complexity_arch
- # }
- #
- # if _get_model_complexity_info:
- # return get_model_complexity_info
- #
- # model.eval()
- # analysis_results = get_model_complexity_info(
- # model,
- # input_shape,
- # show_table=show_table,
- # show_arch=show_arch,
- # )
- # flops = analysis_results['flops_str']
- # params = analysis_results['params_str']
- # # activations = analysis_results['activations_str']
- # out_table = analysis_results['out_table']
- # out_arch = analysis_results['out_arch']
- #
- # if show_arch:
- # print(out_arch)
- #
- # if show_table:
- # print(out_table)
- #
- # split_line = '=' * 30
- # print(f'{split_line}\nInput shape: {input_shape}\t'
- # f'Flops: {flops}\tParams: {params}\t'
- # # f'Activation: {activations}\n{split_line}'
- # , flush=True)
- #
- # # print('!!!Only the backbone network is counted in FLOPs analysis.')
- # # print('!!!Please be cautious if you use the results in papers. '
- # # 'You may need to check if all ops are supported and verify that the '
- # # 'flops computation is correct.')
- #
- # @classmethod
- # def mmdet_flops(cls, config=None, extra_config=None):
- # from mmengine.config import Config
- # from mmengine.runner import Runner
- # import numpy as np
- # import os
- #
- # cfg = Config.fromfile(config)
- # if "model" in cfg:
- # if "pretrained" in cfg["model"]:
- # cfg["model"].pop("pretrained")
- # if extra_config is not None:
- # new_cfg = Config.fromfile(extra_config)
- # new_cfg["model"] = cfg["model"]
- # cfg = new_cfg
- # cfg["work_dir"] = "/tmp"
- # cfg["default_scope"] = "mmdet"
- # runner = Runner.from_cfg(cfg)
- # model = runner.model.cuda()
- # get_model_complexity_info = cls.mmengine_flop_count(_get_model_complexity_info=True)
- #
- # if True:
- # oridir = os.getcwd()
- # os.chdir(os.path.join(os.path.dirname(__file__), "../detection"))
- # data_loader = runner.val_dataloader
- # num_images = 100
- # mean_flops = []
- # for idx, data_batch in enumerate(data_loader):
- # if idx == num_images:
- # break
- # data = model.data_preprocessor(data_batch)
- # model.forward = partial(model.forward, data_samples=data['data_samples'])
- # # out = get_model_complexity_info(model, inputs=data['inputs'])
- # out = get_model_complexity_info(model, input_shape=(3, 1280, 800))
- # params = out['params_str']
- # mean_flops.append(out['flops'])
- # mean_flops = np.average(np.array(mean_flops))
- # print(params, mean_flops)
- # os.chdir(oridir)
- #
- # @classmethod
- # def mmseg_flops(cls, config=None, input_shape=(3, 512, 2048)):
- # from mmengine.config import Config
- # from mmengine.runner import Runner
- #
- # cfg = Config.fromfile(config)
- # cfg["work_dir"] = "/tmp"
- # cfg["default_scope"] = "mmseg"
- # runner = Runner.from_cfg(cfg)
- # model = runner.model.cuda()
- #
- # cls.fvcore_flop_count(model, input_shape=input_shape)
|