voc_aug.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. from functools import partial
  5. import numpy as np
  6. from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress
  7. from PIL import Image
  8. from scipy.io import loadmat
# Expected number of entries in the generated train-aug list; ``main``
# compares the merged list length against this value before writing it.
AUG_LEN = 10582
  10. def convert_mat(mat_file, in_dir, out_dir):
  11. data = loadmat(osp.join(in_dir, mat_file))
  12. mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
  13. seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
  14. Image.fromarray(mask).save(seg_filename, 'PNG')
  15. def generate_aug_list(merged_list, excluded_list):
  16. return list(set(merged_list) - set(excluded_list))
  17. def parse_args():
  18. parser = argparse.ArgumentParser(
  19. description='Convert PASCAL VOC annotations to mmsegmentation format')
  20. parser.add_argument('devkit_path', help='pascal voc devkit path')
  21. parser.add_argument('aug_path', help='pascal voc aug path')
  22. parser.add_argument('-o', '--out_dir', help='output path')
  23. parser.add_argument(
  24. '--nproc', default=1, type=int, help='number of process')
  25. args = parser.parse_args()
  26. return args
  27. def main():
  28. args = parse_args()
  29. devkit_path = args.devkit_path
  30. aug_path = args.aug_path
  31. nproc = args.nproc
  32. if args.out_dir is None:
  33. out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
  34. else:
  35. out_dir = args.out_dir
  36. mkdir_or_exist(out_dir)
  37. in_dir = osp.join(aug_path, 'dataset', 'cls')
  38. track_parallel_progress(
  39. partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
  40. list(scandir(in_dir, suffix='.mat')),
  41. nproc=nproc)
  42. full_aug_list = []
  43. with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
  44. full_aug_list += [line.strip() for line in f]
  45. with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
  46. full_aug_list += [line.strip() for line in f]
  47. with open(
  48. osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
  49. 'train.txt')) as f:
  50. ori_train_list = [line.strip() for line in f]
  51. with open(
  52. osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
  53. 'val.txt')) as f:
  54. val_list = [line.strip() for line in f]
  55. aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
  56. val_list)
  57. assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
  58. AUG_LEN)
  59. with open(
  60. osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
  61. 'trainaug.txt'), 'w') as f:
  62. f.writelines(line + '\n' for line in aug_train_list)
  63. aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
  64. assert len(aug_list) == AUG_LEN - len(
  65. ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
  66. len(ori_train_list))
  67. with open(
  68. osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
  69. 'w') as f:
  70. f.writelines(line + '\n' for line in aug_list)
  71. print('Done!')
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()