get_flops.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import tempfile
  4. from functools import partial
  5. from pathlib import Path
  6. import numpy as np
  7. import torch
  8. from mmengine.config import Config, DictAction
  9. from mmengine.logging import MMLogger
  10. from mmengine.model import revert_sync_batchnorm
  11. from mmengine.registry import init_default_scope
  12. from mmengine.runner import Runner
  13. from mmengine.utils import digit_version
  14. from mmdet.registry import MODELS
  15. try:
  16. from mmengine.analysis import get_model_complexity_info
  17. from mmengine.analysis.print_helper import _format_size
  18. except ImportError:
  19. raise ImportError('Please upgrade mmengine >= 0.6.0')
  20. def parse_args():
  21. parser = argparse.ArgumentParser(description='Get a detector flops')
  22. parser.add_argument('config', help='train config file path')
  23. parser.add_argument(
  24. '--num-images',
  25. type=int,
  26. default=100,
  27. help='num images of calculate model flops')
  28. parser.add_argument(
  29. '--cfg-options',
  30. nargs='+',
  31. action=DictAction,
  32. help='override some settings in the used config, the key-value pair '
  33. 'in xxx=yyy format will be merged into config file. If the value to '
  34. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  35. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  36. 'Note that the quotation marks are necessary and that no white space '
  37. 'is allowed.')
  38. args = parser.parse_args()
  39. return args
  40. def inference(args, logger):
  41. if digit_version(torch.__version__) < digit_version('1.12'):
  42. logger.warning(
  43. 'Some config files, such as configs/yolact and configs/detectors,'
  44. 'may have compatibility issues with torch.jit when torch<1.12. '
  45. 'If you want to calculate flops for these models, '
  46. 'please make sure your pytorch version is >=1.12.')
  47. config_name = Path(args.config)
  48. if not config_name.exists():
  49. logger.error(f'{config_name} not found.')
  50. cfg = Config.fromfile(args.config)
  51. cfg.val_dataloader.batch_size = 1
  52. cfg.work_dir = tempfile.TemporaryDirectory().name
  53. if args.cfg_options is not None:
  54. cfg.merge_from_dict(args.cfg_options)
  55. init_default_scope(cfg.get('default_scope', 'mmdet'))
  56. # TODO: The following usage is temporary and not safe
  57. # use hard code to convert mmSyncBN to SyncBN. This is a known
  58. # bug in mmengine, mmSyncBN requires a distributed environment,
  59. # this question involves models like configs/strong_baselines
  60. if hasattr(cfg, 'head_norm_cfg'):
  61. cfg['head_norm_cfg'] = dict(type='SyncBN', requires_grad=True)
  62. cfg['model']['roi_head']['bbox_head']['norm_cfg'] = dict(
  63. type='SyncBN', requires_grad=True)
  64. cfg['model']['roi_head']['mask_head']['norm_cfg'] = dict(
  65. type='SyncBN', requires_grad=True)
  66. result = {}
  67. avg_flops = []
  68. data_loader = Runner.build_dataloader(cfg.val_dataloader)
  69. model = MODELS.build(cfg.model)
  70. if torch.cuda.is_available():
  71. model = model.cuda()
  72. model = revert_sync_batchnorm(model)
  73. model.eval()
  74. _forward = model.forward
  75. for idx, data_batch in enumerate(data_loader):
  76. if idx == args.num_images:
  77. break
  78. data = model.data_preprocessor(data_batch)
  79. result['ori_shape'] = data['data_samples'][0].ori_shape
  80. result['pad_shape'] = data['data_samples'][0].pad_shape
  81. if hasattr(data['data_samples'][0], 'batch_input_shape'):
  82. result['pad_shape'] = data['data_samples'][0].batch_input_shape
  83. model.forward = partial(_forward, data_samples=data['data_samples'])
  84. outputs = get_model_complexity_info(
  85. model,
  86. None,
  87. inputs=data['inputs'],
  88. show_table=False,
  89. show_arch=False)
  90. avg_flops.append(outputs['flops'])
  91. params = outputs['params']
  92. result['compute_type'] = 'dataloader: load a picture from the dataset'
  93. del data_loader
  94. mean_flops = _format_size(int(np.average(avg_flops)))
  95. params = _format_size(params)
  96. result['flops'] = mean_flops
  97. result['params'] = params
  98. return result
  99. def main():
  100. args = parse_args()
  101. logger = MMLogger.get_instance(name='MMLogger')
  102. result = inference(args, logger)
  103. split_line = '=' * 30
  104. ori_shape = result['ori_shape']
  105. pad_shape = result['pad_shape']
  106. flops = result['flops']
  107. params = result['params']
  108. compute_type = result['compute_type']
  109. if pad_shape != ori_shape:
  110. print(f'{split_line}\nUse size divisor set input shape '
  111. f'from {ori_shape} to {pad_shape}')
  112. print(f'{split_line}\nCompute type: {compute_type}\n'
  113. f'Input shape: {pad_shape}\nFlops: {flops}\n'
  114. f'Params: {params}\n{split_line}')
  115. print('!!!Please be cautious if you use the results in papers. '
  116. 'You may need to check if all ops are supported and verify '
  117. 'that the flops computation is correct.')
  118. if __name__ == '__main__':
  119. main()