levircd.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import glob
  4. import math
  5. import os
  6. import os.path as osp
  7. import mmcv
  8. import numpy as np
  9. from mmengine.utils import ProgressBar
  10. def parse_args():
  11. parser = argparse.ArgumentParser(
  12. description='Convert levir-cd dataset to mmsegmentation format')
  13. parser.add_argument('--dataset_path', help='potsdam folder path')
  14. parser.add_argument('-o', '--out_dir', help='output path')
  15. parser.add_argument(
  16. '--clip_size',
  17. type=int,
  18. help='clipped size of image after preparation',
  19. default=256)
  20. parser.add_argument(
  21. '--stride_size',
  22. type=int,
  23. help='stride of clipping original images',
  24. default=256)
  25. args = parser.parse_args()
  26. return args
  27. def main():
  28. args = parse_args()
  29. input_folder = args.dataset_path
  30. png_files = glob.glob(
  31. os.path.join(input_folder, '**/*.png'), recursive=True)
  32. output_folder = args.out_dir
  33. prog_bar = ProgressBar(len(png_files))
  34. for png_file in png_files:
  35. new_path = os.path.join(
  36. output_folder,
  37. os.path.relpath(os.path.dirname(png_file), input_folder))
  38. os.makedirs(os.path.dirname(new_path), exist_ok=True)
  39. label = False
  40. if 'label' in png_file:
  41. label = True
  42. clip_big_image(png_file, new_path, args, label)
  43. prog_bar.update()
  44. def clip_big_image(image_path, clip_save_dir, args, to_label=False):
  45. image = mmcv.imread(image_path)
  46. h, w, c = image.shape
  47. clip_size = args.clip_size
  48. stride_size = args.stride_size
  49. num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
  50. (h - clip_size) /
  51. stride_size) * stride_size + clip_size >= h else math.ceil(
  52. (h - clip_size) / stride_size) + 1
  53. num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
  54. (w - clip_size) /
  55. stride_size) * stride_size + clip_size >= w else math.ceil(
  56. (w - clip_size) / stride_size) + 1
  57. x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
  58. xmin = x * clip_size
  59. ymin = y * clip_size
  60. xmin = xmin.ravel()
  61. ymin = ymin.ravel()
  62. xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
  63. np.zeros_like(xmin))
  64. ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
  65. np.zeros_like(ymin))
  66. boxes = np.stack([
  67. xmin + xmin_offset, ymin + ymin_offset,
  68. np.minimum(xmin + clip_size, w),
  69. np.minimum(ymin + clip_size, h)
  70. ],
  71. axis=1)
  72. if to_label:
  73. image[image == 255] = 1
  74. image = image[:, :, 0]
  75. for box in boxes:
  76. start_x, start_y, end_x, end_y = box
  77. clipped_image = image[start_y:end_y, start_x:end_x] \
  78. if to_label else image[start_y:end_y, start_x:end_x, :]
  79. idx = osp.basename(image_path).split('.')[0]
  80. mmcv.imwrite(
  81. clipped_image.astype(np.uint8),
  82. osp.join(clip_save_dir,
  83. f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()