get_flops.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import tempfile
  4. from pathlib import Path
  5. import torch
  6. from mmengine import Config, DictAction
  7. from mmengine.logging import MMLogger
  8. from mmengine.model import revert_sync_batchnorm
  9. from mmengine.registry import init_default_scope
  10. from mmseg.models import BaseSegmentor
  11. from mmseg.registry import MODELS
  12. from mmseg.structures import SegDataSample
  13. try:
  14. from mmengine.analysis import get_model_complexity_info
  15. from mmengine.analysis.print_helper import _format_size
  16. except ImportError:
  17. raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')
  18. def parse_args():
  19. parser = argparse.ArgumentParser(
  20. description='Get the FLOPs of a segmentor')
  21. parser.add_argument('config', help='train config file path')
  22. parser.add_argument(
  23. '--shape',
  24. type=int,
  25. nargs='+',
  26. default=[2048, 1024],
  27. help='input image size')
  28. parser.add_argument(
  29. '--cfg-options',
  30. nargs='+',
  31. action=DictAction,
  32. help='override some settings in the used config, the key-value pair '
  33. 'in xxx=yyy format will be merged into config file. If the value to '
  34. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  35. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  36. 'Note that the quotation marks are necessary and that no white space '
  37. 'is allowed.')
  38. args = parser.parse_args()
  39. return args
  40. def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
  41. config_name = Path(args.config)
  42. if not config_name.exists():
  43. logger.error(f'Config file {config_name} does not exist')
  44. cfg: Config = Config.fromfile(config_name)
  45. cfg.work_dir = tempfile.TemporaryDirectory().name
  46. cfg.log_level = 'WARN'
  47. if args.cfg_options is not None:
  48. cfg.merge_from_dict(args.cfg_options)
  49. init_default_scope(cfg.get('scope', 'mmseg'))
  50. if len(args.shape) == 1:
  51. input_shape = (3, args.shape[0], args.shape[0])
  52. elif len(args.shape) == 2:
  53. input_shape = (3, ) + tuple(args.shape)
  54. else:
  55. raise ValueError('invalid input shape')
  56. result = {}
  57. model: BaseSegmentor = MODELS.build(cfg.model)
  58. if hasattr(model, 'auxiliary_head'):
  59. model.auxiliary_head = None
  60. if torch.cuda.is_available():
  61. model.cuda()
  62. model = revert_sync_batchnorm(model)
  63. result['ori_shape'] = input_shape[-2:]
  64. result['pad_shape'] = input_shape[-2:]
  65. data_batch = {
  66. 'inputs': [torch.rand(input_shape)],
  67. 'data_samples': [SegDataSample(metainfo=result)]
  68. }
  69. data = model.data_preprocessor(data_batch)
  70. model.eval()
  71. if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
  72. # TODO: Support MaskFormer and Mask2Former
  73. raise NotImplementedError('MaskFormer and Mask2Former are not '
  74. 'supported yet.')
  75. outputs = get_model_complexity_info(
  76. model,
  77. input_shape,
  78. inputs=data['inputs'],
  79. show_table=False,
  80. show_arch=False)
  81. result['flops'] = _format_size(outputs['flops'])
  82. result['params'] = _format_size(outputs['params'])
  83. result['compute_type'] = 'direct: randomly generate a picture'
  84. return result
  85. def main():
  86. args = parse_args()
  87. logger = MMLogger.get_instance(name='MMLogger')
  88. result = inference(args, logger)
  89. split_line = '=' * 30
  90. ori_shape = result['ori_shape']
  91. pad_shape = result['pad_shape']
  92. flops = result['flops']
  93. params = result['params']
  94. compute_type = result['compute_type']
  95. if pad_shape != ori_shape:
  96. print(f'{split_line}\nUse size divisor set input shape '
  97. f'from {ori_shape} to {pad_shape}')
  98. print(f'{split_line}\nCompute type: {compute_type}\n'
  99. f'Input shape: {pad_shape}\nFlops: {flops}\n'
  100. f'Params: {params}\n{split_line}')
  101. print('!!!Please be cautious if you use the results in papers. '
  102. 'You may need to check if all ops are supported and verify '
  103. 'that the flops computation is correct.')
  104. if __name__ == '__main__':
  105. main()