isaid.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import glob
  4. import os
  5. import os.path as osp
  6. import shutil
  7. import tempfile
  8. import zipfile
  9. import mmcv
  10. import numpy as np
  11. from mmengine.utils import ProgressBar, mkdir_or_exist
  12. from PIL import Image
  13. iSAID_palette = \
  14. {
  15. 0: (0, 0, 0),
  16. 1: (0, 0, 63),
  17. 2: (0, 63, 63),
  18. 3: (0, 63, 0),
  19. 4: (0, 63, 127),
  20. 5: (0, 63, 191),
  21. 6: (0, 63, 255),
  22. 7: (0, 127, 63),
  23. 8: (0, 127, 127),
  24. 9: (0, 0, 127),
  25. 10: (0, 0, 191),
  26. 11: (0, 0, 255),
  27. 12: (0, 191, 127),
  28. 13: (0, 127, 191),
  29. 14: (0, 127, 255),
  30. 15: (0, 100, 155)
  31. }
  32. iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()}
  33. def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
  34. """RGB-color encoding to grayscale labels."""
  35. arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)
  36. for c, i in palette.items():
  37. m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2)
  38. arr_2d[m] = i
  39. return arr_2d
  40. def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
  41. img = np.asarray(Image.open(src_path).convert('RGB'))
  42. img_H, img_W, _ = img.shape
  43. if img_H < patch_H and img_W > patch_W:
  44. img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0)
  45. img_H, img_W, _ = img.shape
  46. elif img_H > patch_H and img_W < patch_W:
  47. img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0)
  48. img_H, img_W, _ = img.shape
  49. elif img_H < patch_H and img_W < patch_W:
  50. img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0)
  51. img_H, img_W, _ = img.shape
  52. for x in range(0, img_W, patch_W - overlap):
  53. for y in range(0, img_H, patch_H - overlap):
  54. x_str = x
  55. x_end = x + patch_W
  56. if x_end > img_W:
  57. diff_x = x_end - img_W
  58. x_str -= diff_x
  59. x_end = img_W
  60. y_str = y
  61. y_end = y + patch_H
  62. if y_end > img_H:
  63. diff_y = y_end - img_H
  64. y_str -= diff_y
  65. y_end = img_H
  66. img_patch = img[y_str:y_end, x_str:x_end, :]
  67. img_patch = Image.fromarray(img_patch.astype(np.uint8))
  68. image = osp.basename(src_path).split('.')[0] + '_' + str(
  69. y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str(
  70. x_end) + '.png'
  71. # print(image)
  72. save_path_image = osp.join(out_dir, 'img_dir', mode, str(image))
  73. img_patch.save(save_path_image, format='BMP')
  74. def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
  75. label = mmcv.imread(src_path, channel_order='rgb')
  76. label = iSAID_convert_from_color(label)
  77. img_H, img_W = label.shape
  78. if img_H < patch_H and img_W > patch_W:
  79. label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255)
  80. img_H = patch_H
  81. elif img_H > patch_H and img_W < patch_W:
  82. label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255)
  83. img_W = patch_W
  84. elif img_H < patch_H and img_W < patch_W:
  85. label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255)
  86. img_H = patch_H
  87. img_W = patch_W
  88. for x in range(0, img_W, patch_W - overlap):
  89. for y in range(0, img_H, patch_H - overlap):
  90. x_str = x
  91. x_end = x + patch_W
  92. if x_end > img_W:
  93. diff_x = x_end - img_W
  94. x_str -= diff_x
  95. x_end = img_W
  96. y_str = y
  97. y_end = y + patch_H
  98. if y_end > img_H:
  99. diff_y = y_end - img_H
  100. y_str -= diff_y
  101. y_end = img_H
  102. lab_patch = label[y_str:y_end, x_str:x_end]
  103. lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P')
  104. image = osp.basename(src_path).split('.')[0].split(
  105. '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str(
  106. x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png'
  107. lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image)))
  108. def parse_args():
  109. parser = argparse.ArgumentParser(
  110. description='Convert iSAID dataset to mmsegmentation format')
  111. parser.add_argument('dataset_path', help='iSAID folder path')
  112. parser.add_argument('--tmp_dir', help='path of the temporary directory')
  113. parser.add_argument('-o', '--out_dir', help='output path')
  114. parser.add_argument(
  115. '--patch_width',
  116. default=896,
  117. type=int,
  118. help='Width of the cropped image patch')
  119. parser.add_argument(
  120. '--patch_height',
  121. default=896,
  122. type=int,
  123. help='Height of the cropped image patch')
  124. parser.add_argument(
  125. '--overlap_area', default=384, type=int, help='Overlap area')
  126. args = parser.parse_args()
  127. return args
  128. def main():
  129. args = parse_args()
  130. dataset_path = args.dataset_path
  131. # image patch width and height
  132. patch_H, patch_W = args.patch_width, args.patch_height
  133. overlap = args.overlap_area # overlap area
  134. if args.out_dir is None:
  135. out_dir = osp.join('data', 'iSAID')
  136. else:
  137. out_dir = args.out_dir
  138. print('Making directories...')
  139. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
  140. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
  141. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test'))
  142. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
  143. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
  144. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test'))
  145. assert os.path.exists(os.path.join(dataset_path, 'train')), \
  146. f'train is not in {dataset_path}'
  147. assert os.path.exists(os.path.join(dataset_path, 'val')), \
  148. f'val is not in {dataset_path}'
  149. assert os.path.exists(os.path.join(dataset_path, 'test')), \
  150. f'test is not in {dataset_path}'
  151. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  152. for dataset_mode in ['train', 'val', 'test']:
  153. # for dataset_mode in [ 'test']:
  154. print(f'Extracting {dataset_mode}ing.zip...')
  155. img_zipp_list = glob.glob(
  156. os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
  157. print('Find the data', img_zipp_list)
  158. for img_zipp in img_zipp_list:
  159. zip_file = zipfile.ZipFile(img_zipp)
  160. zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img'))
  161. src_path_list = glob.glob(
  162. os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))
  163. src_prog_bar = ProgressBar(len(src_path_list))
  164. for i, img_path in enumerate(src_path_list):
  165. if dataset_mode != 'test':
  166. slide_crop_image(img_path, out_dir, dataset_mode, patch_H,
  167. patch_W, overlap)
  168. else:
  169. shutil.move(img_path,
  170. os.path.join(out_dir, 'img_dir', dataset_mode))
  171. src_prog_bar.update()
  172. if dataset_mode != 'test':
  173. label_zipp_list = glob.glob(
  174. os.path.join(dataset_path, dataset_mode, 'Semantic_masks',
  175. '*.zip'))
  176. for label_zipp in label_zipp_list:
  177. zip_file = zipfile.ZipFile(label_zipp)
  178. zip_file.extractall(
  179. os.path.join(tmp_dir, dataset_mode, 'lab'))
  180. lab_path_list = glob.glob(
  181. os.path.join(tmp_dir, dataset_mode, 'lab', 'images',
  182. '*.png'))
  183. lab_prog_bar = ProgressBar(len(lab_path_list))
  184. for i, lab_path in enumerate(lab_path_list):
  185. slide_crop_label(lab_path, out_dir, dataset_mode, patch_H,
  186. patch_W, overlap)
  187. lab_prog_bar.update()
  188. print('Removing the temporary files...')
  189. print('Done!')
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()