| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192 |
- # Copyright (c) OpenMMLab. All rights reserved.
- import argparse
- import os.path as osp
- from functools import partial
- import numpy as np
- from mmengine.utils import mkdir_or_exist, scandir, track_parallel_progress
- from PIL import Image
- from scipy.io import loadmat
# Exact number of image IDs the generated ``trainaug.txt`` split must hold;
# main() refuses to write the split files when the computed lists disagree.
AUG_LEN = 10582
def convert_mat(mat_file, in_dir, out_dir):
    """Convert one ``.mat`` annotation file into a PNG segmentation mask.

    Args:
        mat_file (str): Name of the ``.mat`` file, relative to ``in_dir``.
        in_dir (str): Directory that contains ``mat_file``.
        out_dir (str): Directory the PNG is written to. The PNG keeps the
            base name of ``mat_file`` with its extension changed to ``.png``.
    """
    data = loadmat(osp.join(in_dir, mat_file))
    # Unwrap the nested MATLAB struct ('GTcls' -> 'Segmentation') to reach the
    # label array, and cast it to uint8 before handing it to PIL.
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # Swap only the trailing extension; str.replace('.mat', '.png') would
    # also rewrite any '.mat' that happens to occur earlier in the name.
    png_name = osp.splitext(mat_file)[0] + '.png'
    Image.fromarray(mask).save(osp.join(out_dir, png_name), 'PNG')
def generate_aug_list(merged_list, excluded_list):
    """Return the unique IDs of ``merged_list`` that are not in ``excluded_list``.

    The result is sorted so that the split files written from it are
    identical across runs. A plain ``list(set(...))`` would follow string
    hash order, which changes with Python's per-process hash randomization.

    Args:
        merged_list (Iterable[str]): Candidate IDs; duplicates are collapsed.
        excluded_list (Iterable[str]): IDs to drop from the result.

    Returns:
        list[str]: Sorted, de-duplicated IDs.
    """
    return sorted(set(merged_list) - set(excluded_list))
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: with ``devkit_path`` and ``aug_path``
        (positional), ``out_dir`` (``None`` unless ``-o/--out_dir`` is given)
        and ``nproc`` (int, defaults to 1).
    """
    cli = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    cli.add_argument('devkit_path', help='pascal voc devkit path')
    cli.add_argument('aug_path', help='pascal voc aug path')
    cli.add_argument('-o', '--out_dir', help='output path')
    cli.add_argument('--nproc', type=int, default=1, help='number of process')
    return cli.parse_args()
def _read_ids(path):
    """Read an ID list file (one ID per line), skipping blank lines."""
    with open(path, encoding='utf-8') as f:
        # Blank lines would otherwise become '' entries that pollute the
        # generated splits and throw off the length checks in main().
        return [name for name in (line.strip() for line in f) if name]


def _write_ids(path, ids):
    """Write ``ids`` to ``path``, one per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(name + '\n' for name in ids)


def main():
    """Convert augmented VOC annotations and write the trainaug/aug splits.

    Steps:
      1. Convert every ``.mat`` file in ``<aug_path>/dataset/cls`` into a PNG
         mask in ``out_dir`` (default:
         ``<devkit_path>/VOC2012/SegmentationClassAug``).
      2. Write ``trainaug.txt``: the original VOC train IDs plus the aug
         train/val IDs, minus the original VOC val IDs.
      3. Write ``aug.txt``: the aug train/val IDs minus both the original
         VOC train and val IDs.

    Both split files are written to
    ``<devkit_path>/VOC2012/ImageSets/Segmentation``.

    Raises:
        ValueError: If a generated split does not have the expected size.
            An explicit raise is used instead of ``assert`` so the check
            still runs under ``python -O``.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    nproc = args.nproc
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    mkdir_or_exist(out_dir)

    in_dir = osp.join(aug_path, 'dataset', 'cls')
    track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(scandir(in_dir, suffix='.mat')),
        nproc=nproc)

    split_dir = osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation')
    full_aug_list = (
        _read_ids(osp.join(aug_path, 'dataset', 'train.txt')) +
        _read_ids(osp.join(aug_path, 'dataset', 'val.txt')))
    ori_train_list = _read_ids(osp.join(split_dir, 'train.txt'))
    val_list = _read_ids(osp.join(split_dir, 'val.txt'))

    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    if len(aug_train_list) != AUG_LEN:
        raise ValueError('len(aug_train_list) != {} (got {})'.format(
            AUG_LEN, len(aug_train_list)))
    _write_ids(osp.join(split_dir, 'trainaug.txt'), aug_train_list)

    expected_aug_len = AUG_LEN - len(ori_train_list)
    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    if len(aug_list) != expected_aug_len:
        raise ValueError('len(aug_list) != {} (got {})'.format(
            expected_aug_len, len(aug_list)))
    _write_ids(osp.join(split_dir, 'aug.txt'), aug_list)

    print('Done!')
if __name__ == '__main__':
    # Run the conversion only when executed as a script, never on import.
    main()
|