| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import os
- import os.path as osp
- from PIL import Image
- import six
- import lmdb
- import pickle
- import copy
- import random
- import numpy as np
- import torch.utils.data as data
- from torch.utils.data import DataLoader
- from torchvision import datasets, transforms
- from torchvision.datasets.folder import ImageFolder, IMG_EXTENSIONS
- # from util.data import get_img_loader
- from data.utils import get_transforms, get_scales
- from . import DATA
- import torch
- import json
@DATA.register_module
class ImageFolderLMDB(data.Dataset):
    """Image dataset backed by an LMDB database.

    The database is read from ``{cfg.data.root}/train.lmdb`` when ``train`` is
    true, otherwise from ``{cfg.data.root}/val.lmdb``. It must contain two
    pickled metadata entries, ``b'__len__'`` (number of samples) and
    ``b'__keys__'`` (list of record keys), and one pickled
    ``(image_bytes, target)`` record per key -- the layout produced by
    ``folder2lmdb``.

    The LMDB environment is opened lazily, once per process, so the dataset
    can be handed to ``DataLoader`` workers (forked or spawned) without
    sharing or pickling a live LMDB handle.

    NOTE(security): records are decoded with ``pickle.loads``; only open
    databases that come from a trusted source.

    Args:
        cfg: Config object; only ``cfg.data.root`` is read here.
        train (bool): Select ``train.lmdb`` (True) or ``val.lmdb`` (False).
        transform (callable, optional): Applied to the decoded RGB PIL image.
        target_transform (callable, optional): Applied to the stored target.
    """

    # NOTE: a per-batch random multi-scale resize (driven by
    # cfg.trainer.scale_kwargs) was sketched here previously and is disabled.

    def __init__(self, cfg, train=True, transform=None, target_transform=None):
        self.cfg = cfg
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        self.loader = pickle.loads
        self.db_path = '{}/{}.lmdb'.format(cfg.data.root, 'train' if train else 'val')

        # Live handles are created on first access in each process.
        self.env = None
        self.txn = None
        self._owner_pid = None

        # Read the metadata through a short-lived handle so that nothing
        # LMDB-related is still open when worker processes are created.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                raw_len = txn.get(b'__len__')
                raw_keys = txn.get(b'__keys__')
        finally:
            env.close()
        if raw_len is None or raw_keys is None:
            raise ValueError(
                "LMDB at '{}' is missing the '__len__' / '__keys__' metadata entries".format(self.db_path)
            )
        self.length = pickle.loads(raw_len)
        self.keys = pickle.loads(raw_keys)

    def _open_env(self):
        """Open the database read-only, without locking or read-ahead."""
        return lmdb.open(
            self.db_path,
            subdir=osp.isdir(self.db_path),
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

    def _ensure_open(self):
        """Make sure ``self.env`` / ``self.txn`` belong to the current process."""
        pid = os.getpid()
        if self.txn is None or self._owner_pid != pid:
            # Either first use, or we are in a forked child that inherited the
            # parent's handle: LMDB handles must not be reused across fork().
            self.env = self._open_env()
            self.txn = self.env.begin(write=False)
            self._owner_pid = pid

    def __getstate__(self):
        # LMDB handles cannot be pickled (needed for spawn-started workers);
        # each process reopens the database on demand.
        state = self.__dict__.copy()
        state['env'] = None
        state['txn'] = None
        state['_owner_pid'] = None
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``{'img': image, 'target': target}`` for sample ``index``."""
        self._ensure_open()
        key = self.keys[index]
        byteflow = self.txn.get(key)
        if byteflow is None:
            raise KeyError("record {!r} not found in '{}'".format(key, self.db_path))
        imgbuf, target = self.loader(byteflow)

        # Decode the stored encoded image bytes; convert() forces a full load.
        img = Image.open(six.BytesIO(imgbuf)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return {'img': img, 'target': target}
@DATA.register_module
class CustomImageDataset(data.Dataset):
    """Image dataset described by a JSON annotation file.

    ``cfg.data.root`` must contain ``train.json`` / ``val.json`` -- each a JSON
    object mapping an image file name to its label -- and an ``images/``
    sub-directory holding the referenced files.
    """

    def __init__(self, cfg, train=True, transform=None, target_transform=None):
        """
        Args:
            cfg: Config object; ``cfg.data.root`` is the directory with the
                json files and the ``images/`` folder.
            train (bool): If True, load train.json, else load val.json.
            transform (callable, optional): Optional transform to be applied
                on the loaded RGB image.
            target_transform (callable, optional): Optional transform to be
                applied on the label.
        """
        self.root_dir = cfg.data.root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        # Determine which json file to load
        json_file = 'train.json' if train else 'val.json'
        json_path = os.path.join(self.root_dir, json_file)

        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        # List of (image file name, label) pairs, in file order
        self.img_labels = list(self.annotations.items())
        self.length = len(self.img_labels)

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``{'img': image, 'target': label}`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name, label = self.img_labels[idx]
        img_path = os.path.join(self.root_dir, 'images', img_name)

        # convert() returns a fully loaded copy, so the file handle can be
        # released deterministically right after it.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return {'img': image, 'target': label}
def _read_raw_bytes(path):
    """Return the undecoded file content of ``path``."""
    with open(path, 'rb') as f:
        return f.read()


def _identity_collate(batch):
    """Return the list of samples unchanged (no tensor stacking)."""
    return batch


def folder2lmdb(root, name="train", write_frequency=1000, num_workers=32, map_size=1099511627776 * 2):
    """Pack the image folder ``{root}/{name}`` into ``{root}/{name}.lmdb``.

    Adapted from
    https://github.com/xunge/pytorch_lmdb_imagenet/blob/master/folder2lmdb.py

    Every sample is stored under its ASCII-encoded running index as a pickled
    ``(raw_file_bytes, label)`` tuple. Two extra entries are written at the
    end: ``b'__keys__'`` (pickled list of all record keys) and ``b'__len__'``
    (pickled number of records).

    Args:
        root: Directory containing the ``name`` image folder; the database is
            created next to it.
        name: Sub-folder to convert; also the stem of the database name.
        write_frequency: Commit the write transaction every this many samples.
        num_workers: Number of DataLoader worker processes used for reading.
        map_size: Maximum size in bytes the database may grow to.
    """
    img_dir = f'{root}/{name}'
    dataset = ImageFolder(root=img_dir, loader=_read_raw_bytes)
    # Module-level loader/collate functions stay picklable, which worker
    # processes started with the "spawn" method require.
    data_loader = DataLoader(dataset, num_workers=num_workers, collate_fn=_identity_collate)

    lmdb_path = osp.join(root, f'{name}.lmdb')
    db = lmdb.open(lmdb_path, subdir=True, map_size=map_size, readonly=False, meminit=False, map_async=True)
    try:
        total = len(data_loader)
        written = 0
        txn = db.begin(write=True)
        for idx, batch in enumerate(data_loader):
            # Default batch size is 1 and collation is the identity, so each
            # batch is a one-element list holding a (bytes, label) sample.
            image, label = batch[0]
            txn.put(str(idx).encode('ascii'), pickle.dumps((image, label)))
            written = idx + 1
            if written % write_frequency == 0:
                print(f'{name} {written}/{total}')
                # Commit periodically to bound the size of one transaction.
                txn.commit()
                txn = db.begin(write=True)
        txn.commit()

        keys = [str(k).encode('ascii') for k in range(written)]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', pickle.dumps(keys))
            txn.put(b'__len__', pickle.dumps(len(keys)))
        db.sync()
    finally:
        # Always release the environment, even if reading or writing failed.
        db.close()
-
|