CLS_dataset.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import os
  2. import glob
  3. import json
  4. import random
  5. from torch.utils.data import dataset
  6. from torchvision import datasets, transforms
  7. from torchvision.datasets.folder import ImageFolder, IMG_EXTENSIONS
  8. from util.data import get_img_loader
  9. from data.utils import get_transforms, get_scales
  10. from typing import Any, Callable, cast, Dict, List, Optional, Tuple
  11. import torch.utils.data as data
  12. import numpy as np
  13. from PIL import Image
  14. import warnings
  15. warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
  16. from . import DATA
  17. @DATA.register_module
  18. class DefaultCLS(datasets.folder.DatasetFolder): # ImageNet
  19. def __init__(self, cfg, train=True, transform=None, target_transform=None):
  20. root = '{}/{}'.format(cfg.data.root, 'train' if train else 'val')
  21. img_loader = get_img_loader(cfg.data.loader_type)
  22. super(DefaultCLS, self).__init__(root=root, loader=img_loader, extensions=IMG_EXTENSIONS, transform=transform, target_transform=target_transform)
  23. self.cfg = cfg
  24. self.train = train
  25. # scale_kwargs = cfg.trainer.scale_kwargs
  26. # if scale_kwargs is not None and scale_kwargs['n_scale'] > 0:
  27. # scale_kwargs = {k: v for k, v in scale_kwargs.items()}
  28. # self.scales = get_scales(**scale_kwargs)
  29. # else:
  30. # self.scales = [(cfg.size, cfg.size)]
  31. # self.num = 0
  32. # self.batch_size_per_gpu = cfg.trainer.data.batch_size_per_gpu
  33. self.nb_classes = cfg.data.nb_classes
  34. self.data_all = self.samples
  35. self.length = len(self.data_all)
  36. # def reset_scale_transform(self):
  37. # scale_rand = random.choices(self.scales, k=1)[0]
  38. # scale_rand = scale_rand[0]
  39. # self.cfg.size = scale_rand
  40. # self.cfg.data.train_transforms[0]['input_size'] = scale_rand
  41. # self.transform = get_transforms(self.cfg, train=True, cfg_transforms=self.cfg.data.train_transforms)
  42. def __len__(self):
  43. return self.length
  44. def __getitem__(self, index):
  45. # if len(self.scales) > 1 and self.num % self.batch_size_per_gpu == 0:
  46. # self.reset_scale_transform()
  47. # self.num += 1
  48. path, target = self.data_all[index]
  49. img = self.loader(path)
  50. img = self.transform(img) if self.transform is not None else img
  51. target = self.target_transform(target) if self.target_transform is not None else target
  52. return {'img': img, 'target': target}
  53. class INatDataset(ImageFolder):
  54. def __init__(self, root, train=True, transform=None, year=2018):
  55. super(INatDataset, self).__init__(root=root)
  56. self.transform = transform
  57. self.year = year
  58. # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
  59. category = 'name'
  60. path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
  61. with open(path_json) as json_file:
  62. data = json.load(json_file)
  63. with open(os.path.join(root, 'categories.json')) as json_file:
  64. data_catg = json.load(json_file)
  65. path_json_for_targeter = os.path.join(root, f"train{year}.json")
  66. with open(path_json_for_targeter) as json_file:
  67. data_for_targeter = json.load(json_file)
  68. targeter = {}
  69. indexer = 0
  70. for elem in data_for_targeter['annotations']:
  71. king = []
  72. king.append(data_catg[int(elem['category_id'])][category])
  73. if king[0] not in targeter.keys():
  74. targeter[king[0]] = indexer
  75. indexer += 1
  76. self.nb_classes = len(targeter)
  77. self.samples = []
  78. for elem in data['images']:
  79. cut = elem['file_name'].split('/')
  80. target_current = int(cut[2])
  81. path_current = os.path.join(root, cut[0], cut[2], cut[3])
  82. categors = data_catg[target_current]
  83. target_current_true = targeter[categors[category]]
  84. self.samples.append((path_current, target_current_true))
  85. ### ImageNet21K
# Download page: https://opendatalab.com/ImageNet-21k/download
# The raw archive is split into six parts, ImageNet21k-00.zip ... ImageNet21k-05.zip.
# NOTE: the part links below are pre-signed CDN URLs (Expires=1672760771, i.e.
# 2023-01-03 UTC) and have expired; request fresh links from the download page.
# part 00: https://cdn.opendatalab.com/ImageNet-21k/raw/ImageNet21k-00.zip?Expires=1672760771&OSSAccessKeyId=LTAI5tCYLi1ZnJqYZX4tRk4q&Signature=N4NFPdRbLCQPYH6aT%2B9rISmeQ9Q%3D&response-content-type=application%2Foctet-stream
# part 01: https://cdn.opendatalab.com/ImageNet-21k/raw/ImageNet21k-01.zip?Expires=1672760771&OSSAccessKeyId=LTAI5tCYLi1ZnJqYZX4tRk4q&Signature=07xGKO%2BN01MqrHnmJnrOlJwWrFU%3D&response-content-type=application%2Foctet-stream
# part 02: https://cdn.opendatalab.com/ImageNet-21k/raw/ImageNet21k-02.zip?Expires=1672760771&OSSAccessKeyId=LTAI5tCYLi1ZnJqYZX4tRk4q&Signature=I6rWgueKX44byBdpvlne2YeZCgY%3D&response-content-type=application%2Foctet-stream
# part 03: https://cdn.opendatalab.com/ImageNet-21k/raw/ImageNet21k-03.zip?Expires=1672760771&OSSAccessKeyId=LTAI5tCYLi1ZnJqYZX4tRk4q&Signature=uLMG9ndTodDAl81ltGNO73avRTM%3D&response-content-type=application%2Foctet-stream
# part 04: https://cdn.opendatalab.com/ImageNet-21k/raw/ImageNet21k-04.zip?Expires=1672760771&OSSAccessKeyId=LTAI5tCYLi1ZnJqYZX4tRk4q&Signature=UCVQdB4Mei%2B2q1NUCe00rKwLKpM%3D&response-content-type=application%2Foctet-stream
# part 05: https://cdn.opendatalab.com/ImageNet-21k/raw/ImageNet21k-05.zip?Expires=1672760771&OSSAccessKeyId=LTAI5tCYLi1ZnJqYZX4tRk4q&Signature=bdz69KI85tFHmmsQydK5JWR%2BhcM%3D&response-content-type=application%2Foctet-stream
  93. # train-val split files: https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md
  94. @DATA.register_module
  95. class IN22KDataset(data.Dataset):
  96. def __init__(self, cfg, train=True, transform=None, target_transform=None):
  97. super(IN22KDataset, self).__init__()
  98. self.root = cfg.data.root
  99. self.loader = get_img_loader(cfg.data.loader_type)
  100. self.ann_path = f"{self.root}/ILSVRC2011fall_whole_map_{'train' if train else 'val'}.txt"
  101. self.cfg = cfg
  102. self.train = train
  103. self.transform = transform
  104. self.target_transform = target_transform
  105. # id & label: https://github.com/google-research/big_transfer/issues/7
  106. # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027
  107. self.nb_classes = cfg.data.nb_classes
  108. self.data_all = json.load(open(self.ann_path))
  109. self.length = len(self.data_all)
  110. def __len__(self):
  111. return self.length
  112. def __getitem__(self, index):
  113. path, target = self.data_all[index]
  114. img = self.loader(f'{self.root}/{path}')
  115. img = self.transform(img) if self.transform is not None else img
  116. target = self.target_transform(target) if self.target_transform is not None else target
  117. return {'img': img, 'target': target}
  118. @DATA.register_module
  119. def Cifar10CLS(cfg, train=True, transforms=None):
  120. return datasets.CIFAR10(cfg.data.root, train=train, transform=transforms)
  121. @DATA.register_module
  122. def Cifar100CLS(cfg, train=True, transforms=None):
  123. return datasets.CIFAR100(cfg.data.root, train=train, transform=transforms)
  124. @DATA.register_module
  125. def INAT18CLS(cfg, train=True, transforms=None):
  126. return INatDataset(cfg.data.root, train=train, transforms=transforms, year=2018)
  127. @DATA.register_module
  128. def INAT19CLS(cfg, train=True, transforms=None):
  129. return INatDataset(cfg.data.root, train=train, transforms=transforms, year=2019)