chase_db1.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. import os.path as osp
  5. import tempfile
  6. import zipfile
  7. import mmcv
  8. from mmengine.utils import mkdir_or_exist
  9. CHASE_DB1_LEN = 28 * 3
  10. TRAINING_LEN = 60
  11. def parse_args():
  12. parser = argparse.ArgumentParser(
  13. description='Convert CHASE_DB1 dataset to mmsegmentation format')
  14. parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
  15. parser.add_argument('--tmp_dir', help='path of the temporary directory')
  16. parser.add_argument('-o', '--out_dir', help='output path')
  17. args = parser.parse_args()
  18. return args
  19. def main():
  20. args = parse_args()
  21. dataset_path = args.dataset_path
  22. if args.out_dir is None:
  23. out_dir = osp.join('data', 'CHASE_DB1')
  24. else:
  25. out_dir = args.out_dir
  26. print('Making directories...')
  27. mkdir_or_exist(out_dir)
  28. mkdir_or_exist(osp.join(out_dir, 'images'))
  29. mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
  30. mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
  31. mkdir_or_exist(osp.join(out_dir, 'annotations'))
  32. mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
  33. mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
  34. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  35. print('Extracting CHASEDB1.zip...')
  36. zip_file = zipfile.ZipFile(dataset_path)
  37. zip_file.extractall(tmp_dir)
  38. print('Generating training dataset...')
  39. assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
  40. f'len(os.listdir(tmp_dir)) != {CHASE_DB1_LEN}'
  41. for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
  42. img = mmcv.imread(osp.join(tmp_dir, img_name))
  43. if osp.splitext(img_name)[1] == '.jpg':
  44. mmcv.imwrite(
  45. img,
  46. osp.join(out_dir, 'images', 'training',
  47. osp.splitext(img_name)[0] + '.png'))
  48. else:
  49. # The annotation img should be divided by 128, because some of
  50. # the annotation imgs are not standard. We should set a
  51. # threshold to convert the nonstandard annotation imgs. The
  52. # value divided by 128 is equivalent to '1 if value >= 128
  53. # else 0'
  54. mmcv.imwrite(
  55. img[:, :, 0] // 128,
  56. osp.join(out_dir, 'annotations', 'training',
  57. osp.splitext(img_name)[0] + '.png'))
  58. for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
  59. img = mmcv.imread(osp.join(tmp_dir, img_name))
  60. if osp.splitext(img_name)[1] == '.jpg':
  61. mmcv.imwrite(
  62. img,
  63. osp.join(out_dir, 'images', 'validation',
  64. osp.splitext(img_name)[0] + '.png'))
  65. else:
  66. mmcv.imwrite(
  67. img[:, :, 0] // 128,
  68. osp.join(out_dir, 'annotations', 'validation',
  69. osp.splitext(img_name)[0] + '.png'))
  70. print('Removing the temporary files...')
  71. print('Done!')
  72. if __name__ == '__main__':
  73. main()