| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os.path as osp
- import nibabel as nib
- import numpy as np
- from mmengine.utils import mkdir_or_exist
- from PIL import Image
def read_files_from_txt(txt_path):
    """Read a split file and return its non-empty lines, stripped.

    Args:
        txt_path (str): Path to a text file with one entry per line.

    Returns:
        list[str]: Each line with surrounding whitespace removed. Blank
        lines are dropped rather than returned as ``''`` entries.
    """
    with open(txt_path, encoding='utf-8') as f:
        # A stray empty line would otherwise yield an empty entry, which a
        # caller building file names from it cannot use.
        return [line.strip() for line in f if line.strip()]
def read_nii_file(nii_path):
    """Load a NIfTI file and return its voxel data.

    Args:
        nii_path (str): Path to a ``.nii``/``.nii.gz`` file.

    Returns:
        np.ndarray: The array produced by nibabel's ``get_fdata()``.
    """
    volume = nib.load(nii_path)
    return volume.get_fdata()
def split_3d_image(img):
    """Slice a 3D array along its first axis.

    Args:
        img (np.ndarray): Array with exactly three dimensions.

    Returns:
        list[np.ndarray]: ``img[k, :, :]`` for every ``k`` along axis 0.

    Raises:
        ValueError: If ``img`` does not have exactly three dimensions.
    """
    # Unpacking the shape doubles as the check that ``img`` is 3D.
    depth, _, _ = img.shape
    return [img[k, :, :] for k in range(depth)]
def label_mapping(label):
    """Remap original class ids to the 9-class TransUNet setting.

    The resulting classes are 'background', 'aorta', 'gallbladder',
    'left_kidney', 'right_kidney', 'liver', 'pancreas', 'spleen' and
    'stomach', respectively. Every other foreground class of the original
    dataset is set to background (0).
    More details could be found here: https://arxiv.org/abs/2102.04306

    Args:
        label (np.ndarray): Label array holding the original class ids.

    Returns:
        np.ndarray: Array of the same shape and dtype with remapped ids.
    """
    # original id -> TransUNet id; anything not listed stays 0.
    id_map = {8: 1, 4: 2, 3: 3, 2: 4, 6: 5, 11: 6, 1: 7, 7: 8}
    mapped = np.zeros_like(label)
    for src_id, dst_id in id_map.items():
        # Masks are taken from the untouched input, so order does not matter.
        mapped[label == src_id] = dst_id
    return mapped
def pares_args(argv=None):
    """Parse command-line options for the Synapse conversion script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace: With ``dataset_path`` and ``save_path``.
    """
    parser = argparse.ArgumentParser(
        description='Convert synapse dataset to mmsegmentation format')
    # Required: without it ``dataset_path`` would be None and the caller's
    # path checks would crash with a TypeError instead of a usage message.
    parser.add_argument(
        '--dataset-path',
        type=str,
        required=True,
        help='synapse dataset path.')
    parser.add_argument(
        '--save-path',
        default='data/synapse',
        type=str,
        help='save path of the dataset.')
    return parser.parse_args(argv)
def _load_case(dataset_path, idx):
    """Load and preprocess the image and label volumes of one case.

    Args:
        dataset_path (str): Root of the raw dataset.
        idx (str): Case id used in ``img<idx>.nii.gz``/``label<idx>.nii.gz``.

    Returns:
        tuple[np.ndarray, np.ndarray]: Image volume rescaled to [0, 255] and
        remapped label volume, both with the original axis 2 moved to the
        front so ``vol[c]`` is the c-th 2D slice.
    """
    img_3d = read_nii_file(
        osp.join(dataset_path, 'img', 'img' + idx + '.nii.gz'))
    label_3d = read_nii_file(
        osp.join(dataset_path, 'label', 'label' + idx + '.nii.gz'))

    # Clip intensities to [-125, 275] (presumably a Hounsfield-unit window,
    # as in TransUNet) and rescale linearly to [0, 255].
    img_3d = np.clip(img_3d, -125, 275)
    img_3d = (img_3d + 125) / 400
    img_3d *= 255

    # Move axis 2 to the front so slices index along axis 0, then mirror each
    # slice along its last axis. The label gets the identical transform so it
    # stays aligned with the image.
    img_3d = np.flip(np.transpose(img_3d, [2, 0, 1]), 2)
    label_3d = np.flip(np.transpose(label_3d, [2, 0, 1]), 2)
    label_3d = label_mapping(label_3d)
    return img_3d, label_3d


def _convert_split(dataset_path, save_path, case_ids, split):
    """Write every slice of the given cases as a JPEG image and PNG label.

    Args:
        dataset_path (str): Root of the raw dataset.
        save_path (str): Output root; files go to ``img_dir/<split>`` and
            ``ann_dir/<split>``, which must already exist.
        case_ids (list[str]): Case ids to convert.
        split (str): Output subset name, e.g. ``'train'`` or ``'val'``.
    """
    img_dir = osp.join(save_path, 'img_dir/' + split)
    ann_dir = osp.join(save_path, 'ann_dir/' + split)
    for idx in case_ids:
        img_3d, label_3d = _load_case(dataset_path, idx)
        for c in range(img_3d.shape[0]):
            stem = 'case' + idx.zfill(4) + '_slice' + str(c).zfill(3)
            img = Image.fromarray(img_3d[c]).convert('RGB')
            label = Image.fromarray(label_3d[c]).convert('L')
            img.save(osp.join(img_dir, stem + '.jpg'))
            label.save(osp.join(ann_dir, stem + '.png'))


def main():
    """Convert the raw Synapse dataset into the mmsegmentation layout.

    Reads case lists from ``train.txt`` and ``val.txt`` in the dataset root
    and writes per-slice images/labels under ``img_dir`` and ``ann_dir`` of
    the save path.

    Raises:
        ValueError: If the dataset path is missing or does not exist.
        FileNotFoundError: If the ``img`` or ``label`` directory is missing.
    """
    args = pares_args()
    dataset_path = args.dataset_path
    save_path = args.save_path

    # Guard against None (argument omitted) before calling ``osp.exists``,
    # which raises TypeError on None.
    if not dataset_path or not osp.exists(dataset_path):
        raise ValueError('The dataset path does not exist. '
                         'Please enter a correct dataset path.')
    if not osp.exists(osp.join(dataset_path, 'img')) \
            or not osp.exists(osp.join(dataset_path, 'label')):
        raise FileNotFoundError('The dataset structure is incorrect. '
                                'Please check your dataset.')

    # Characters [3:7] of each entry are taken as the case id; presumably
    # entries look like ``img0005.nii.gz`` — confirm against the split files.
    train_id = [
        idx[3:7]
        for idx in read_files_from_txt(osp.join(dataset_path, 'train.txt'))
    ]
    test_id = [
        idx[3:7]
        for idx in read_files_from_txt(osp.join(dataset_path, 'val.txt'))
    ]

    for sub_dir in ('img_dir/train', 'img_dir/val', 'ann_dir/train',
                    'ann_dir/val'):
        mkdir_or_exist(osp.join(save_path, sub_dir))

    # It follows data preparation pipeline from here:
    # https://github.com/Beckschen/TransUNet/tree/main/datasets
    _convert_split(dataset_path, save_path, train_id, 'train')
    _convert_split(dataset_path, save_path, test_id, 'val')
if __name__ == '__main__':
    # Only run the conversion when executed as a script, not on import.
    main()
|