stare.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import gzip
  4. import os
  5. import os.path as osp
  6. import tarfile
  7. import tempfile
  8. import mmcv
  9. from mmengine.utils import mkdir_or_exist
  10. STARE_LEN = 20
  11. TRAINING_LEN = 10
  12. def un_gz(src, dst):
  13. g_file = gzip.GzipFile(src)
  14. with open(dst, 'wb+') as f:
  15. f.write(g_file.read())
  16. g_file.close()
  17. def parse_args():
  18. parser = argparse.ArgumentParser(
  19. description='Convert STARE dataset to mmsegmentation format')
  20. parser.add_argument('image_path', help='the path of stare-images.tar')
  21. parser.add_argument('labels_ah', help='the path of labels-ah.tar')
  22. parser.add_argument('labels_vk', help='the path of labels-vk.tar')
  23. parser.add_argument('--tmp_dir', help='path of the temporary directory')
  24. parser.add_argument('-o', '--out_dir', help='output path')
  25. args = parser.parse_args()
  26. return args
  27. def main():
  28. args = parse_args()
  29. image_path = args.image_path
  30. labels_ah = args.labels_ah
  31. labels_vk = args.labels_vk
  32. if args.out_dir is None:
  33. out_dir = osp.join('data', 'STARE')
  34. else:
  35. out_dir = args.out_dir
  36. print('Making directories...')
  37. mkdir_or_exist(out_dir)
  38. mkdir_or_exist(osp.join(out_dir, 'images'))
  39. mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
  40. mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
  41. mkdir_or_exist(osp.join(out_dir, 'annotations'))
  42. mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
  43. mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
  44. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  45. mkdir_or_exist(osp.join(tmp_dir, 'gz'))
  46. mkdir_or_exist(osp.join(tmp_dir, 'files'))
  47. print('Extracting stare-images.tar...')
  48. with tarfile.open(image_path) as f:
  49. f.extractall(osp.join(tmp_dir, 'gz'))
  50. for filename in os.listdir(osp.join(tmp_dir, 'gz')):
  51. un_gz(
  52. osp.join(tmp_dir, 'gz', filename),
  53. osp.join(tmp_dir, 'files',
  54. osp.splitext(filename)[0]))
  55. now_dir = osp.join(tmp_dir, 'files')
  56. assert len(os.listdir(now_dir)) == STARE_LEN, \
  57. f'len(os.listdir(now_dir)) != {STARE_LEN}'
  58. for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
  59. img = mmcv.imread(osp.join(now_dir, filename))
  60. mmcv.imwrite(
  61. img,
  62. osp.join(out_dir, 'images', 'training',
  63. osp.splitext(filename)[0] + '.png'))
  64. for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
  65. img = mmcv.imread(osp.join(now_dir, filename))
  66. mmcv.imwrite(
  67. img,
  68. osp.join(out_dir, 'images', 'validation',
  69. osp.splitext(filename)[0] + '.png'))
  70. print('Removing the temporary files...')
  71. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  72. mkdir_or_exist(osp.join(tmp_dir, 'gz'))
  73. mkdir_or_exist(osp.join(tmp_dir, 'files'))
  74. print('Extracting labels-ah.tar...')
  75. with tarfile.open(labels_ah) as f:
  76. f.extractall(osp.join(tmp_dir, 'gz'))
  77. for filename in os.listdir(osp.join(tmp_dir, 'gz')):
  78. un_gz(
  79. osp.join(tmp_dir, 'gz', filename),
  80. osp.join(tmp_dir, 'files',
  81. osp.splitext(filename)[0]))
  82. now_dir = osp.join(tmp_dir, 'files')
  83. assert len(os.listdir(now_dir)) == STARE_LEN, \
  84. f'len(os.listdir(now_dir)) != {STARE_LEN}'
  85. for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
  86. img = mmcv.imread(osp.join(now_dir, filename))
  87. # The annotation img should be divided by 128, because some of
  88. # the annotation imgs are not standard. We should set a threshold
  89. # to convert the nonstandard annotation imgs. The value divided by
  90. # 128 equivalent to '1 if value >= 128 else 0'
  91. mmcv.imwrite(
  92. img[:, :, 0] // 128,
  93. osp.join(out_dir, 'annotations', 'training',
  94. osp.splitext(filename)[0] + '.png'))
  95. for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
  96. img = mmcv.imread(osp.join(now_dir, filename))
  97. mmcv.imwrite(
  98. img[:, :, 0] // 128,
  99. osp.join(out_dir, 'annotations', 'validation',
  100. osp.splitext(filename)[0] + '.png'))
  101. print('Removing the temporary files...')
  102. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  103. mkdir_or_exist(osp.join(tmp_dir, 'gz'))
  104. mkdir_or_exist(osp.join(tmp_dir, 'files'))
  105. print('Extracting labels-vk.tar...')
  106. with tarfile.open(labels_vk) as f:
  107. f.extractall(osp.join(tmp_dir, 'gz'))
  108. for filename in os.listdir(osp.join(tmp_dir, 'gz')):
  109. un_gz(
  110. osp.join(tmp_dir, 'gz', filename),
  111. osp.join(tmp_dir, 'files',
  112. osp.splitext(filename)[0]))
  113. now_dir = osp.join(tmp_dir, 'files')
  114. assert len(os.listdir(now_dir)) == STARE_LEN, \
  115. f'len(os.listdir(now_dir)) != {STARE_LEN}'
  116. for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
  117. img = mmcv.imread(osp.join(now_dir, filename))
  118. mmcv.imwrite(
  119. img[:, :, 0] // 128,
  120. osp.join(out_dir, 'annotations', 'training',
  121. osp.splitext(filename)[0] + '.png'))
  122. for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
  123. img = mmcv.imread(osp.join(now_dir, filename))
  124. mmcv.imwrite(
  125. img[:, :, 0] // 128,
  126. osp.join(out_dir, 'annotations', 'validation',
  127. osp.splitext(filename)[0] + '.png'))
  128. print('Removing the temporary files...')
  129. print('Done!')
  130. if __name__ == '__main__':
  131. main()