test.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. import os.path as osp
  5. from mmengine.config import Config, DictAction
  6. from mmengine.runner import Runner
  7. import model
  8. # TODO: support fuse_conv_bn, visualization, and format_only
  9. def parse_args():
  10. parser = argparse.ArgumentParser(
  11. description='MMSeg test (and eval) a model')
  12. parser.add_argument('config', help='train config file path')
  13. parser.add_argument('checkpoint', help='checkpoint file')
  14. parser.add_argument(
  15. '--work-dir',
  16. help=('if specified, the evaluation metric results will be dumped'
  17. 'into the directory as json'))
  18. parser.add_argument(
  19. '--out',
  20. type=str,
  21. help='The directory to save output prediction for offline evaluation')
  22. parser.add_argument(
  23. '--show', action='store_true', help='show prediction results')
  24. parser.add_argument(
  25. '--show-dir',
  26. help='directory where painted images will be saved. '
  27. 'If specified, it will be automatically saved '
  28. 'to the work_dir/timestamp/show_dir')
  29. parser.add_argument(
  30. '--wait-time', type=float, default=2, help='the interval of show (s)')
  31. parser.add_argument(
  32. '--cfg-options',
  33. nargs='+',
  34. action=DictAction,
  35. help='override some settings in the used config, the key-value pair '
  36. 'in xxx=yyy format will be merged into config file. If the value to '
  37. 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
  38. 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
  39. 'Note that the quotation marks are necessary and that no white space '
  40. 'is allowed.')
  41. parser.add_argument(
  42. '--launcher',
  43. choices=['none', 'pytorch', 'slurm', 'mpi'],
  44. default='none',
  45. help='job launcher')
  46. parser.add_argument(
  47. '--tta', action='store_true', help='Test time augmentation')
  48. # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
  49. # will pass the `--local-rank` parameter to `tools/train.py` instead
  50. # of `--local_rank`.
  51. parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
  52. args = parser.parse_args()
  53. if 'LOCAL_RANK' not in os.environ:
  54. os.environ['LOCAL_RANK'] = str(args.local_rank)
  55. return args
  56. def trigger_visualization_hook(cfg, args):
  57. default_hooks = cfg.default_hooks
  58. if 'visualization' in default_hooks:
  59. visualization_hook = default_hooks['visualization']
  60. # Turn on visualization
  61. visualization_hook['draw'] = True
  62. if args.show:
  63. visualization_hook['show'] = True
  64. visualization_hook['wait_time'] = args.wait_time
  65. if args.show_dir:
  66. visualizer = cfg.visualizer
  67. visualizer['save_dir'] = args.show_dir
  68. else:
  69. raise RuntimeError(
  70. 'VisualizationHook must be included in default_hooks.'
  71. 'refer to usage '
  72. '"visualization=dict(type=\'VisualizationHook\')"')
  73. return cfg
  74. def main():
  75. args = parse_args()
  76. # load config
  77. cfg = Config.fromfile(args.config)
  78. cfg.launcher = args.launcher
  79. if args.cfg_options is not None:
  80. cfg.merge_from_dict(args.cfg_options)
  81. # work_dir is determined in this priority: CLI > segment in file > filename
  82. if args.work_dir is not None:
  83. # update configs according to CLI args if args.work_dir is not None
  84. cfg.work_dir = args.work_dir
  85. elif cfg.get('work_dir', None) is None:
  86. # use config filename as default work_dir if cfg.work_dir is None
  87. cfg.work_dir = osp.join('./work_dirs',
  88. osp.splitext(osp.basename(args.config))[0])
  89. cfg.load_from = args.checkpoint
  90. if args.show or args.show_dir:
  91. cfg = trigger_visualization_hook(cfg, args)
  92. if args.tta:
  93. cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline
  94. cfg.tta_model.module = cfg.model
  95. cfg.model = cfg.tta_model
  96. # add output_dir in metric
  97. if args.out is not None:
  98. cfg.test_evaluator['output_dir'] = args.out
  99. cfg.test_evaluator['keep_results'] = True
  100. # build the runner from config
  101. runner = Runner.from_cfg(cfg)
  102. # start testing
  103. runner.test()
  104. if __name__ == '__main__':
  105. main()