synapse.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import os.path as osp
  4. import nibabel as nib
  5. import numpy as np
  6. from mmengine.utils import mkdir_or_exist
  7. from PIL import Image
  8. def read_files_from_txt(txt_path):
  9. with open(txt_path) as f:
  10. files = f.readlines()
  11. files = [file.strip() for file in files]
  12. return files
  13. def read_nii_file(nii_path):
  14. img = nib.load(nii_path).get_fdata()
  15. return img
  16. def split_3d_image(img):
  17. c, _, _ = img.shape
  18. res = []
  19. for i in range(c):
  20. res.append(img[i, :, :])
  21. return res
  22. def label_mapping(label):
  23. """Label mapping from TransUNet paper setting. It only has 9 classes, which
  24. are 'background', 'aorta', 'gallbladder', 'left_kidney', 'right_kidney',
  25. 'liver', 'pancreas', 'spleen', 'stomach', respectively. Other foreground
  26. classes in original dataset are all set to background.
  27. More details could be found here: https://arxiv.org/abs/2102.04306
  28. """
  29. maped_label = np.zeros_like(label)
  30. maped_label[label == 8] = 1
  31. maped_label[label == 4] = 2
  32. maped_label[label == 3] = 3
  33. maped_label[label == 2] = 4
  34. maped_label[label == 6] = 5
  35. maped_label[label == 11] = 6
  36. maped_label[label == 1] = 7
  37. maped_label[label == 7] = 8
  38. return maped_label
  39. def pares_args():
  40. parser = argparse.ArgumentParser(
  41. description='Convert synapse dataset to mmsegmentation format')
  42. parser.add_argument(
  43. '--dataset-path', type=str, help='synapse dataset path.')
  44. parser.add_argument(
  45. '--save-path',
  46. default='data/synapse',
  47. type=str,
  48. help='save path of the dataset.')
  49. args = parser.parse_args()
  50. return args
  51. def main():
  52. args = pares_args()
  53. dataset_path = args.dataset_path
  54. save_path = args.save_path
  55. if not osp.exists(dataset_path):
  56. raise ValueError('The dataset path does not exist. '
  57. 'Please enter a correct dataset path.')
  58. if not osp.exists(osp.join(dataset_path, 'img')) \
  59. or not osp.exists(osp.join(dataset_path, 'label')):
  60. raise FileNotFoundError('The dataset structure is incorrect. '
  61. 'Please check your dataset.')
  62. train_id = read_files_from_txt(osp.join(dataset_path, 'train.txt'))
  63. train_id = [idx[3:7] for idx in train_id]
  64. test_id = read_files_from_txt(osp.join(dataset_path, 'val.txt'))
  65. test_id = [idx[3:7] for idx in test_id]
  66. mkdir_or_exist(osp.join(save_path, 'img_dir/train'))
  67. mkdir_or_exist(osp.join(save_path, 'img_dir/val'))
  68. mkdir_or_exist(osp.join(save_path, 'ann_dir/train'))
  69. mkdir_or_exist(osp.join(save_path, 'ann_dir/val'))
  70. # It follows data preparation pipeline from here:
  71. # https://github.com/Beckschen/TransUNet/tree/main/datasets
  72. for i, idx in enumerate(train_id):
  73. img_3d = read_nii_file(
  74. osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
  75. label_3d = read_nii_file(
  76. osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
  77. img_3d = np.clip(img_3d, -125, 275)
  78. img_3d = (img_3d + 125) / 400
  79. img_3d *= 255
  80. img_3d = np.transpose(img_3d, [2, 0, 1])
  81. img_3d = np.flip(img_3d, 2)
  82. label_3d = np.transpose(label_3d, [2, 0, 1])
  83. label_3d = np.flip(label_3d, 2)
  84. label_3d = label_mapping(label_3d)
  85. for c in range(img_3d.shape[0]):
  86. img = img_3d[c]
  87. label = label_3d[c]
  88. img = Image.fromarray(img).convert('RGB')
  89. label = Image.fromarray(label).convert('L')
  90. img.save(
  91. osp.join(
  92. save_path, 'img_dir/train', 'case' + idx.zfill(4) +
  93. '_slice' + str(c).zfill(3) + '.jpg'))
  94. label.save(
  95. osp.join(
  96. save_path, 'ann_dir/train', 'case' + idx.zfill(4) +
  97. '_slice' + str(c).zfill(3) + '.png'))
  98. for i, idx in enumerate(test_id):
  99. img_3d = read_nii_file(
  100. osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
  101. label_3d = read_nii_file(
  102. osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))
  103. img_3d = np.clip(img_3d, -125, 275)
  104. img_3d = (img_3d + 125) / 400
  105. img_3d *= 255
  106. img_3d = np.transpose(img_3d, [2, 0, 1])
  107. img_3d = np.flip(img_3d, 2)
  108. label_3d = np.transpose(label_3d, [2, 0, 1])
  109. label_3d = np.flip(label_3d, 2)
  110. label_3d = label_mapping(label_3d)
  111. for c in range(img_3d.shape[0]):
  112. img = img_3d[c]
  113. label = label_3d[c]
  114. img = Image.fromarray(img).convert('RGB')
  115. label = Image.fromarray(label).convert('L')
  116. img.save(
  117. osp.join(
  118. save_path, 'img_dir/val', 'case' + idx.zfill(4) +
  119. '_slice' + str(c).zfill(3) + '.jpg'))
  120. label.save(
  121. osp.join(
  122. save_path, 'ann_dir/val', 'case' + idx.zfill(4) +
  123. '_slice' + str(c).zfill(3) + '.png'))
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()