# benchmark.py (4.2 KB)

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import time
import warnings

import numpy as np
import torch
from mmengine import Config
from mmengine.fileio import dump
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist

from mmseg.registry import MODELS
  14. def parse_args():
  15. parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
  16. parser.add_argument('config', help='test config file path')
  17. parser.add_argument('checkpoint', help='checkpoint file')
  18. parser.add_argument(
  19. '--log-interval', type=int, default=50, help='interval of logging')
  20. parser.add_argument(
  21. '--work-dir',
  22. help=('if specified, the results will be dumped '
  23. 'into the directory as json'))
  24. parser.add_argument('--repeat-times', type=int, default=1)
  25. args = parser.parse_args()
  26. return args
  27. def main():
  28. args = parse_args()
  29. cfg = Config.fromfile(args.config)
  30. init_default_scope(cfg.get('default_scope', 'mmseg'))
  31. timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
  32. if args.work_dir is not None:
  33. mkdir_or_exist(osp.abspath(args.work_dir))
  34. json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
  35. else:
  36. # use config filename as default work_dir if cfg.work_dir is None
  37. work_dir = osp.join('./work_dirs',
  38. osp.splitext(osp.basename(args.config))[0])
  39. mkdir_or_exist(osp.abspath(work_dir))
  40. json_file = osp.join(work_dir, f'fps_{timestamp}.json')
  41. repeat_times = args.repeat_times
  42. # set cudnn_benchmark
  43. torch.backends.cudnn.benchmark = False
  44. cfg.model.pretrained = None
  45. benchmark_dict = dict(config=args.config, unit='img / s')
  46. overall_fps_list = []
  47. cfg.test_dataloader.batch_size = 1
  48. for time_index in range(repeat_times):
  49. print(f'Run {time_index + 1}:')
  50. # build the dataloader
  51. data_loader = Runner.build_dataloader(cfg.test_dataloader)
  52. # build the model and load checkpoint
  53. cfg.model.train_cfg = None
  54. model = MODELS.build(cfg.model)
  55. if 'checkpoint' in args and osp.exists(args.checkpoint):
  56. load_checkpoint(model, args.checkpoint, map_location='cpu')
  57. if torch.cuda.is_available():
  58. model = model.cuda()
  59. model = revert_sync_batchnorm(model)
  60. model.eval()
  61. # the first several iterations may be very slow so skip them
  62. num_warmup = 5
  63. pure_inf_time = 0
  64. total_iters = 200
  65. # benchmark with 200 batches and take the average
  66. for i, data in enumerate(data_loader):
  67. data = model.data_preprocessor(data, True)
  68. inputs = data['inputs']
  69. data_samples = data['data_samples']
  70. if torch.cuda.is_available():
  71. torch.cuda.synchronize()
  72. start_time = time.perf_counter()
  73. with torch.no_grad():
  74. model(inputs, data_samples, mode='predict')
  75. if torch.cuda.is_available():
  76. torch.cuda.synchronize()
  77. elapsed = time.perf_counter() - start_time
  78. if i >= num_warmup:
  79. pure_inf_time += elapsed
  80. if (i + 1) % args.log_interval == 0:
  81. fps = (i + 1 - num_warmup) / pure_inf_time
  82. print(f'Done image [{i + 1:<3}/ {total_iters}], '
  83. f'fps: {fps:.2f} img / s')
  84. if (i + 1) == total_iters:
  85. fps = (i + 1 - num_warmup) / pure_inf_time
  86. print(f'Overall fps: {fps:.2f} img / s\n')
  87. benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
  88. overall_fps_list.append(fps)
  89. break
  90. benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
  91. benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
  92. print(f'Average fps of {repeat_times} evaluations: '
  93. f'{benchmark_dict["average_fps"]}')
  94. print(f'The variance of {repeat_times} evaluations: '
  95. f'{benchmark_dict["fps_variance"]}')
  96. dump(benchmark_dict, json_file, indent=4)
  97. if __name__ == '__main__':
  98. main()