drive.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os
  4. import os.path as osp
  5. import tempfile
  6. import zipfile
  7. import cv2
  8. import mmcv
  9. from mmengine.utils import mkdir_or_exist
  10. def parse_args():
  11. parser = argparse.ArgumentParser(
  12. description='Convert DRIVE dataset to mmsegmentation format')
  13. parser.add_argument(
  14. 'training_path', help='the training part of DRIVE dataset')
  15. parser.add_argument(
  16. 'testing_path', help='the testing part of DRIVE dataset')
  17. parser.add_argument('--tmp_dir', help='path of the temporary directory')
  18. parser.add_argument('-o', '--out_dir', help='output path')
  19. args = parser.parse_args()
  20. return args
  21. def main():
  22. args = parse_args()
  23. training_path = args.training_path
  24. testing_path = args.testing_path
  25. if args.out_dir is None:
  26. out_dir = osp.join('data', 'DRIVE')
  27. else:
  28. out_dir = args.out_dir
  29. print('Making directories...')
  30. mkdir_or_exist(out_dir)
  31. mkdir_or_exist(osp.join(out_dir, 'images'))
  32. mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
  33. mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
  34. mkdir_or_exist(osp.join(out_dir, 'annotations'))
  35. mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
  36. mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
  37. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  38. print('Extracting training.zip...')
  39. zip_file = zipfile.ZipFile(training_path)
  40. zip_file.extractall(tmp_dir)
  41. print('Generating training dataset...')
  42. now_dir = osp.join(tmp_dir, 'training', 'images')
  43. for img_name in os.listdir(now_dir):
  44. img = mmcv.imread(osp.join(now_dir, img_name))
  45. mmcv.imwrite(
  46. img,
  47. osp.join(
  48. out_dir, 'images', 'training',
  49. osp.splitext(img_name)[0].replace('_training', '') +
  50. '.png'))
  51. now_dir = osp.join(tmp_dir, 'training', '1st_manual')
  52. for img_name in os.listdir(now_dir):
  53. cap = cv2.VideoCapture(osp.join(now_dir, img_name))
  54. ret, img = cap.read()
  55. mmcv.imwrite(
  56. img[:, :, 0] // 128,
  57. osp.join(out_dir, 'annotations', 'training',
  58. osp.splitext(img_name)[0] + '.png'))
  59. print('Extracting test.zip...')
  60. zip_file = zipfile.ZipFile(testing_path)
  61. zip_file.extractall(tmp_dir)
  62. print('Generating validation dataset...')
  63. now_dir = osp.join(tmp_dir, 'test', 'images')
  64. for img_name in os.listdir(now_dir):
  65. img = mmcv.imread(osp.join(now_dir, img_name))
  66. mmcv.imwrite(
  67. img,
  68. osp.join(
  69. out_dir, 'images', 'validation',
  70. osp.splitext(img_name)[0].replace('_test', '') + '.png'))
  71. now_dir = osp.join(tmp_dir, 'test', '1st_manual')
  72. if osp.exists(now_dir):
  73. for img_name in os.listdir(now_dir):
  74. cap = cv2.VideoCapture(osp.join(now_dir, img_name))
  75. ret, img = cap.read()
  76. # The annotation img should be divided by 128, because some of
  77. # the annotation imgs are not standard. We should set a
  78. # threshold to convert the nonstandard annotation imgs. The
  79. # value divided by 128 is equivalent to '1 if value >= 128
  80. # else 0'
  81. mmcv.imwrite(
  82. img[:, :, 0] // 128,
  83. osp.join(out_dir, 'annotations', 'validation',
  84. osp.splitext(img_name)[0] + '.png'))
  85. now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
  86. if osp.exists(now_dir):
  87. for img_name in os.listdir(now_dir):
  88. cap = cv2.VideoCapture(osp.join(now_dir, img_name))
  89. ret, img = cap.read()
  90. mmcv.imwrite(
  91. img[:, :, 0] // 128,
  92. osp.join(out_dir, 'annotations', 'validation',
  93. osp.splitext(img_name)[0] + '.png'))
  94. print('Removing the temporary files...')
  95. print('Done!')
  96. if __name__ == '__main__':
  97. main()