refuge.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. import os.path as osp
  5. import tempfile
  6. import zipfile
  7. import mmcv
  8. import numpy as np
  9. from mmengine.utils import mkdir_or_exist
  10. def parse_args():
  11. parser = argparse.ArgumentParser(
  12. description='Convert REFUGE dataset to mmsegmentation format')
  13. parser.add_argument('--raw_data_root', help='the root path of raw data')
  14. parser.add_argument('--tmp_dir', help='path of the temporary directory')
  15. parser.add_argument('-o', '--out_dir', help='output path')
  16. args = parser.parse_args()
  17. return args
  18. def extract_img(root: str,
  19. cur_dir: str,
  20. out_dir: str,
  21. mode: str = 'train',
  22. file_type: str = 'img') -> None:
  23. """_summary_
  24. Args:
  25. Args:
  26. root (str): root where the extracted data is saved
  27. cur_dir (cur_dir): dir where the zip_file exists
  28. out_dir (str): root dir where the data is saved
  29. mode (str, optional): Defaults to 'train'.
  30. file_type (str, optional): Defaults to 'img',else to 'mask'.
  31. """
  32. zip_file = zipfile.ZipFile(cur_dir)
  33. zip_file.extractall(root)
  34. for cur_dir, dirs, files in os.walk(root):
  35. # filter child dirs and directories with "Illustration" and "MACOSX"
  36. if len(dirs) == 0 and \
  37. cur_dir.split('\\')[-1].find('Illustration') == -1 and \
  38. cur_dir.find('MACOSX') == -1:
  39. file_names = [
  40. file for file in files
  41. if file.endswith('.jpg') or file.endswith('.bmp')
  42. ]
  43. for filename in sorted(file_names):
  44. img = mmcv.imread(osp.join(cur_dir, filename))
  45. if file_type == 'annotations':
  46. img = img[:, :, 0]
  47. img[np.where(img == 0)] = 1
  48. img[np.where(img == 128)] = 2
  49. img[np.where(img == 255)] = 0
  50. mmcv.imwrite(
  51. img,
  52. osp.join(out_dir, file_type, mode,
  53. osp.splitext(filename)[0] + '.png'))
  54. def main():
  55. args = parse_args()
  56. raw_data_root = args.raw_data_root
  57. if args.out_dir is None:
  58. out_dir = osp.join('./data', 'REFUGE')
  59. else:
  60. out_dir = args.out_dir
  61. print('Making directories...')
  62. mkdir_or_exist(out_dir)
  63. mkdir_or_exist(osp.join(out_dir, 'images'))
  64. mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
  65. mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
  66. mkdir_or_exist(osp.join(out_dir, 'images', 'test'))
  67. mkdir_or_exist(osp.join(out_dir, 'annotations'))
  68. mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
  69. mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
  70. mkdir_or_exist(osp.join(out_dir, 'annotations', 'test'))
  71. print('Generating images and annotations...')
  72. # process data from the child dir on the first rank
  73. cur_dir, dirs, files = list(os.walk(raw_data_root))[0]
  74. print('====================')
  75. files = list(filter(lambda x: x.endswith('.zip'), files))
  76. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  77. for file in files:
  78. # search data folders for training,validation,test
  79. mode = list(
  80. filter(lambda x: file.lower().find(x) != -1,
  81. ['training', 'test', 'validation']))[0]
  82. file_root = osp.join(tmp_dir, file[:-4])
  83. file_type = 'images' if file.find('Anno') == -1 and file.find(
  84. 'GT') == -1 else 'annotations'
  85. extract_img(file_root, osp.join(cur_dir, file), out_dir, mode,
  86. file_type)
  87. print('Done!')
# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()