pascal_context.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from functools import partial
  5. import numpy as np
  6. from detail import Detail
  7. from mmengine.utils import mkdir_or_exist, track_progress
  8. from PIL import Image
  9. _mapping = np.sort(
  10. np.array([
  11. 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
  12. 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
  13. 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
  14. 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
  15. ]))
  16. _key = np.array(range(len(_mapping))).astype('uint8')
  17. def generate_labels(img_id, detail, out_dir):
  18. def _class_to_index(mask, _mapping, _key):
  19. # assert the values
  20. values = np.unique(mask)
  21. for i in range(len(values)):
  22. assert (values[i] in _mapping)
  23. index = np.digitize(mask.ravel(), _mapping, right=True)
  24. return _key[index].reshape(mask.shape)
  25. mask = Image.fromarray(
  26. _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
  27. filename = img_id['file_name']
  28. mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
  29. return osp.splitext(osp.basename(filename))[0]
  30. def parse_args():
  31. parser = argparse.ArgumentParser(
  32. description='Convert PASCAL VOC annotations to mmsegmentation format')
  33. parser.add_argument('devkit_path', help='pascal voc devkit path')
  34. parser.add_argument('json_path', help='annoation json filepath')
  35. parser.add_argument('-o', '--out_dir', help='output path')
  36. args = parser.parse_args()
  37. return args
  38. def main():
  39. args = parse_args()
  40. devkit_path = args.devkit_path
  41. if args.out_dir is None:
  42. out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
  43. else:
  44. out_dir = args.out_dir
  45. json_path = args.json_path
  46. mkdir_or_exist(out_dir)
  47. img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')
  48. train_detail = Detail(json_path, img_dir, 'train')
  49. train_ids = train_detail.getImgs()
  50. val_detail = Detail(json_path, img_dir, 'val')
  51. val_ids = val_detail.getImgs()
  52. mkdir_or_exist(
  53. osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))
  54. train_list = track_progress(
  55. partial(generate_labels, detail=train_detail, out_dir=out_dir),
  56. train_ids)
  57. with open(
  58. osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
  59. 'train.txt'), 'w') as f:
  60. f.writelines(line + '\n' for line in sorted(train_list))
  61. val_list = track_progress(
  62. partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
  63. with open(
  64. osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
  65. 'val.txt'), 'w') as f:
  66. f.writelines(line + '\n' for line in sorted(val_list))
  67. print('Done!')
  68. if __name__ == '__main__':
  69. main()