lmdb_dataset.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. import os.path as osp
  3. from PIL import Image
  4. import six
  5. import lmdb
  6. import pickle
  7. import copy
  8. import random
  9. import numpy as np
  10. import torch.utils.data as data
  11. from torch.utils.data import DataLoader
  12. from torchvision import datasets, transforms
  13. from torchvision.datasets.folder import ImageFolder, IMG_EXTENSIONS
  14. # from util.data import get_img_loader
  15. from data.utils import get_transforms, get_scales
  16. from . import DATA
  17. import torch
  18. import json
  19. @DATA.register_module
  20. class ImageFolderLMDB(data.Dataset):
  21. def __init__(self, cfg, train=True, transform=None, target_transform=None):
  22. self.cfg = cfg
  23. self.train = train
  24. self.transform = transform
  25. self.target_transform = target_transform
  26. # scale_kwargs = cfg.trainer.scale_kwargs
  27. # if scale_kwargs is not None and scale_kwargs['n_scale'] > 0:
  28. # scale_kwargs = {k: v for k, v in scale_kwargs.items()}
  29. # self.scales = get_scales(**scale_kwargs)
  30. # else:
  31. # self.scales = [(cfg.size, cfg.size)]
  32. # self.num = 0
  33. # self.batch_size_per_gpu = cfg.trainer.data.batch_size_per_gpu
  34. self.loader = pickle.loads
  35. db_path = '{}/{}.lmdb'.format(cfg.data.root, 'train' if train else 'val')
  36. self.env = lmdb.open(db_path, subdir=osp.isdir(db_path), readonly=True, lock=False, readahead=False, meminit=False)
  37. self.txn = self.env.begin(write=False)
  38. self.length = pickle.loads(self.txn.get(b'__len__'))
  39. self.keys = pickle.loads(self.txn.get(b'__keys__'))
  40. # def reset_scale_transform(self):
  41. # scale_rand = random.choices(self.scales, k=1)[0]
  42. # scale_rand = scale_rand[0]
  43. # self.cfg.size = scale_rand
  44. # self.cfg.data.train_transforms[0]['input_size'] = scale_rand
  45. # self.transform = get_transforms(self.cfg, train=True, cfg_transforms=self.cfg.data.train_transforms)
  46. def __len__(self):
  47. return self.length
  48. def __getitem__(self, index):
  49. # if len(self.scales) > 1 and self.num % self.batch_size_per_gpu == 0:
  50. # self.reset_scale_transform()
  51. # self.num += 1
  52. byteflow = self.txn.get(self.keys[index])
  53. imgbuf, target = self.loader(byteflow)
  54. buf = six.BytesIO()
  55. buf.write(imgbuf)
  56. buf.seek(0)
  57. img = Image.open(buf).convert('RGB')
  58. img = self.transform(img) if self.transform is not None else img
  59. target = self.target_transform(target) if self.target_transform is not None else target
  60. return {'img': img, 'target':target}
  61. @DATA.register_module
  62. class CustomImageDataset(data.Dataset):
  63. def __init__(self, cfg, train=True, transform=None, target_transform=None):
  64. """
  65. Args:
  66. root_dir (string): Directory with all the images and json files.
  67. train (bool): If True, load train.json, else load val.json.
  68. transform (callable, optional): Optional transform to be applied
  69. on a sample.
  70. target_transform (callable, optional): Optional transform to be applied
  71. on the target.
  72. """
  73. self.root_dir = cfg.data.root
  74. self.train = train
  75. self.transform = transform
  76. self.target_transform = target_transform
  77. # Determine which json file to load
  78. json_file = 'train.json' if train else 'val.json'
  79. json_path = os.path.join(self.root_dir, json_file)
  80. # Load the json file
  81. with open(json_path, 'r') as f:
  82. self.annotations = json.load(f)
  83. # Create a list of image file names and their corresponding labels
  84. self.img_labels = list(self.annotations.items())
  85. self.length = len(self.img_labels)
  86. def __len__(self):
  87. return len(self.img_labels)
  88. def __getitem__(self, idx):
  89. if torch.is_tensor(idx):
  90. idx = idx.tolist()
  91. img_name, label = self.img_labels[idx]
  92. img_path = os.path.join(self.root_dir, 'images', img_name)
  93. image = Image.open(img_path).convert('RGB')
  94. if self.transform:
  95. image = self.transform(image)
  96. if self.target_transform:
  97. label = self.target_transform(label)
  98. sample = {'img': image, 'target': label}
  99. return sample
  100. def folder2lmdb(root, name="train", write_frequency=1000):
  101. # https://github.com/xunge/pytorch_lmdb_imagenet/blob/master/folder2lmdb.py
  102. def raw_reader(path):
  103. with open(path, 'rb') as f:
  104. bin_data = f.read()
  105. return bin_data
  106. img_dir = f'{root}/{name}'
  107. dataset = ImageFolder(root=img_dir, loader=raw_reader)
  108. data_loader = DataLoader(dataset, num_workers=32, collate_fn=lambda x: x)
  109. lmdb_path = osp.join(root, f'{name}.lmdb')
  110. db = lmdb.open(lmdb_path, subdir=True, map_size=1099511627776 * 2, readonly=False, meminit=False, map_async=True)
  111. txn = db.begin(write=True)
  112. for idx, data in enumerate(data_loader):
  113. image, label = data[0]
  114. txn.put(u'{}'.format(idx).encode('ascii'), pickle.dumps((image, label)))
  115. if (idx + 1) % write_frequency == 0:
  116. print(f'{name} {idx + 1}/{len(data_loader)}')
  117. txn.commit()
  118. txn = db.begin(write=True)
  119. txn.commit()
  120. keys = [u'{}'.format(k).encode('ascii') for k in range(idx + 1)]
  121. txn = db.begin(write=True)
  122. txn.put(b'__keys__', pickle.dumps(keys))
  123. txn.put(b'__len__', pickle.dumps(len(keys)))
  124. txn.commit()
  125. db.sync()
  126. db.close()