potsdam.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import glob
  4. import math
  5. import os
  6. import os.path as osp
  7. import tempfile
  8. import zipfile
  9. import mmcv
  10. import numpy as np
  11. from mmengine.utils import ProgressBar, mkdir_or_exist
  12. def parse_args():
  13. parser = argparse.ArgumentParser(
  14. description='Convert potsdam dataset to mmsegmentation format')
  15. parser.add_argument('dataset_path', help='potsdam folder path')
  16. parser.add_argument('--tmp_dir', help='path of the temporary directory')
  17. parser.add_argument('-o', '--out_dir', help='output path')
  18. parser.add_argument(
  19. '--clip_size',
  20. type=int,
  21. help='clipped size of image after preparation',
  22. default=512)
  23. parser.add_argument(
  24. '--stride_size',
  25. type=int,
  26. help='stride of clipping original images',
  27. default=256)
  28. args = parser.parse_args()
  29. return args
  30. def clip_big_image(image_path, clip_save_dir, args, to_label=False):
  31. # Original image of Potsdam dataset is very large, thus pre-processing
  32. # of them is adopted. Given fixed clip size and stride size to generate
  33. # clipped image, the intersection of width and height is determined.
  34. # For example, given one 5120 x 5120 original image, the clip size is
  35. # 512 and stride size is 256, thus it would generate 20x20 = 400 images
  36. # whose size are all 512x512.
  37. image = mmcv.imread(image_path)
  38. h, w, c = image.shape
  39. clip_size = args.clip_size
  40. stride_size = args.stride_size
  41. num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
  42. (h - clip_size) /
  43. stride_size) * stride_size + clip_size >= h else math.ceil(
  44. (h - clip_size) / stride_size) + 1
  45. num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
  46. (w - clip_size) /
  47. stride_size) * stride_size + clip_size >= w else math.ceil(
  48. (w - clip_size) / stride_size) + 1
  49. x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
  50. xmin = x * clip_size
  51. ymin = y * clip_size
  52. xmin = xmin.ravel()
  53. ymin = ymin.ravel()
  54. xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
  55. np.zeros_like(xmin))
  56. ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
  57. np.zeros_like(ymin))
  58. boxes = np.stack([
  59. xmin + xmin_offset, ymin + ymin_offset,
  60. np.minimum(xmin + clip_size, w),
  61. np.minimum(ymin + clip_size, h)
  62. ],
  63. axis=1)
  64. if to_label:
  65. color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
  66. [255, 255, 0], [0, 255, 0], [0, 255, 255],
  67. [0, 0, 255]])
  68. flatten_v = np.matmul(
  69. image.reshape(-1, c),
  70. np.array([2, 3, 4]).reshape(3, 1))
  71. out = np.zeros_like(flatten_v)
  72. for idx, class_color in enumerate(color_map):
  73. value_idx = np.matmul(class_color,
  74. np.array([2, 3, 4]).reshape(3, 1))
  75. out[flatten_v == value_idx] = idx
  76. image = out.reshape(h, w)
  77. for box in boxes:
  78. start_x, start_y, end_x, end_y = box
  79. clipped_image = image[start_y:end_y,
  80. start_x:end_x] if to_label else image[
  81. start_y:end_y, start_x:end_x, :]
  82. idx_i, idx_j = osp.basename(image_path).split('_')[2:4]
  83. mmcv.imwrite(
  84. clipped_image.astype(np.uint8),
  85. osp.join(
  86. clip_save_dir,
  87. f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
  88. def main():
  89. args = parse_args()
  90. splits = {
  91. 'train': [
  92. '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11',
  93. '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7',
  94. '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9'
  95. ],
  96. 'val': [
  97. '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13',
  98. '4_15', '2_14', '5_13', '4_13', '3_14', '7_13'
  99. ]
  100. }
  101. dataset_path = args.dataset_path
  102. if args.out_dir is None:
  103. out_dir = osp.join('data', 'potsdam')
  104. else:
  105. out_dir = args.out_dir
  106. print('Making directories...')
  107. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train'))
  108. mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val'))
  109. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train'))
  110. mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val'))
  111. zipp_list = glob.glob(os.path.join(dataset_path, '*.zip'))
  112. print('Find the data', zipp_list)
  113. for zipp in zipp_list:
  114. with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
  115. zip_file = zipfile.ZipFile(zipp)
  116. zip_file.extractall(tmp_dir)
  117. src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif'))
  118. if not len(src_path_list):
  119. sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0])
  120. src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif'))
  121. prog_bar = ProgressBar(len(src_path_list))
  122. for i, src_path in enumerate(src_path_list):
  123. idx_i, idx_j = osp.basename(src_path).split('_')[2:4]
  124. data_type = 'train' if f'{idx_i}_{idx_j}' in splits[
  125. 'train'] else 'val'
  126. if 'label' in src_path:
  127. dst_dir = osp.join(out_dir, 'ann_dir', data_type)
  128. clip_big_image(src_path, dst_dir, args, to_label=True)
  129. else:
  130. dst_dir = osp.join(out_dir, 'img_dir', data_type)
  131. clip_big_image(src_path, dst_dir, args, to_label=False)
  132. prog_bar.update()
  133. print('Removing the temporary files...')
  134. print('Done!')
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not when imported.
    main()