# fuse_results.py (4.3 KB)

  1. import argparse
  2. from mmengine.fileio import dump, load
  3. from mmengine.logging import print_log
  4. from mmengine.utils import ProgressBar
  5. from pycocotools.coco import COCO
  6. from pycocotools.cocoeval import COCOeval
  7. from mmdet.models.utils import weighted_boxes_fusion
  8. def parse_args():
  9. parser = argparse.ArgumentParser(description='Fusion image \
  10. prediction results using Weighted \
  11. Boxes Fusion from multiple models.')
  12. parser.add_argument(
  13. 'pred-results',
  14. type=str,
  15. nargs='+',
  16. help='files of prediction results \
  17. from multiple models, json format.')
  18. parser.add_argument('--annotation', type=str, help='annotation file path')
  19. parser.add_argument(
  20. '--weights',
  21. type=float,
  22. nargs='*',
  23. default=None,
  24. help='weights for each model, '
  25. 'remember to correspond to the above prediction path.')
  26. parser.add_argument(
  27. '--fusion-iou-thr',
  28. type=float,
  29. default=0.55,
  30. help='IoU value for boxes to be a match in wbf.')
  31. parser.add_argument(
  32. '--skip-box-thr',
  33. type=float,
  34. default=0.0,
  35. help='exclude boxes with score lower than this variable in wbf.')
  36. parser.add_argument(
  37. '--conf-type',
  38. type=str,
  39. default='avg',
  40. help='how to calculate confidence in weighted boxes in wbf.')
  41. parser.add_argument(
  42. '--eval-single',
  43. action='store_true',
  44. help='whether evaluate each single model result.')
  45. parser.add_argument(
  46. '--save-fusion-results',
  47. action='store_true',
  48. help='whether save fusion result')
  49. parser.add_argument(
  50. '--out-dir',
  51. type=str,
  52. default='outputs',
  53. help='Output directory of images or prediction results.')
  54. args = parser.parse_args()
  55. return args
  56. def main():
  57. args = parse_args()
  58. assert len(args.models_name) == len(args.pred_results), \
  59. 'the quantities of model names and prediction results are not equal'
  60. cocoGT = COCO(args.annotation)
  61. predicts_raw = []
  62. models_name = ['model_' + str(i) for i in range(len(args.pred_results))]
  63. for model_name, path in \
  64. zip(models_name, args.pred_results):
  65. pred = load(path)
  66. predicts_raw.append(pred)
  67. if args.eval_single:
  68. print_log(f'Evaluate {model_name}...')
  69. cocoDt = cocoGT.loadRes(pred)
  70. coco_eval = COCOeval(cocoGT, cocoDt, iouType='bbox')
  71. coco_eval.evaluate()
  72. coco_eval.accumulate()
  73. coco_eval.summarize()
  74. predict = {
  75. str(image_id): {
  76. 'bboxes_list': [[] for _ in range(len(predicts_raw))],
  77. 'scores_list': [[] for _ in range(len(predicts_raw))],
  78. 'labels_list': [[] for _ in range(len(predicts_raw))]
  79. }
  80. for image_id in cocoGT.getImgIds()
  81. }
  82. for i, pred_single in enumerate(predicts_raw):
  83. for pred in pred_single:
  84. p = predict[str(pred['image_id'])]
  85. p['bboxes_list'][i].append(pred['bbox'])
  86. p['scores_list'][i].append(pred['score'])
  87. p['labels_list'][i].append(pred['category_id'])
  88. result = []
  89. prog_bar = ProgressBar(len(predict))
  90. for image_id, res in predict.items():
  91. bboxes, scores, labels = weighted_boxes_fusion(
  92. res['bboxes_list'],
  93. res['scores_list'],
  94. res['labels_list'],
  95. weights=args.weights,
  96. iou_thr=args.fusion_iou_thr,
  97. skip_box_thr=args.skip_box_thr,
  98. conf_type=args.conf_type)
  99. for bbox, score, label in zip(bboxes, scores, labels):
  100. result.append({
  101. 'bbox': bbox.numpy().tolist(),
  102. 'category_id': int(label),
  103. 'image_id': int(image_id),
  104. 'score': float(score)
  105. })
  106. prog_bar.update()
  107. if args.save_fusion_results:
  108. out_file = args.out_dir + '/fusion_results.json'
  109. dump(result, file=out_file)
  110. print_log(
  111. f'Fusion results have been saved to {out_file}.', logger='current')
  112. print_log('Evaluate fusion results using wbf...')
  113. cocoDt = cocoGT.loadRes(result)
  114. coco_eval = COCOeval(cocoGT, cocoDt, iouType='bbox')
  115. coco_eval.evaluate()
  116. coco_eval.accumulate()
  117. coco_eval.summarize()
  118. if __name__ == '__main__':
  119. main()