transforms.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.transforms.functional as F
  4. from torchvision import transforms
  5. from timm.data import create_transform
  6. import cv2
  7. import numpy as np
  8. from PIL import Image
  9. from . import TRANSFORMS
  10. # for torchvision
  11. tv_tran = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
  12. "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
  13. "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
  14. "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
  15. "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
  16. "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
  17. for tv_tran_name in tv_tran:
  18. tv_transform = getattr(transforms, tv_tran_name, None)
  19. TRANSFORMS.register_module(tv_transform, name=tv_tran_name) if tv_transform else None
  20. # for timm
  21. TRANSFORMS.register_module(create_transform, name='timm_create_transform')
  22. class vt_TransBase(object):
  23. def __init__(self):
  24. pass
  25. def pre_process(self):
  26. pass
  27. def __call__(self, img):
  28. pass
  29. @TRANSFORMS.register_module
  30. class vt_Identity(vt_TransBase):
  31. def __call__(self, img):
  32. return img
  33. @TRANSFORMS.register_module
  34. class vt_Resize(vt_TransBase):
  35. """
  36. Args:
  37. size : h | (h, w)
  38. img : PIL Image
  39. Returns:
  40. PIL Image
  41. """
  42. def __init__(self, size, interpolation=F.InterpolationMode.BICUBIC):
  43. super().__init__()
  44. self.size = size
  45. self.interpolation = interpolation
  46. def __call__(self, img):
  47. return F.resize(img, self.size, self.interpolation)
  48. @TRANSFORMS.register_module
  49. class vt_Compose(vt_TransBase):
  50. def __init__(self, transforms):
  51. super().__init__()
  52. self.transforms = transforms
  53. def pre_process(self):
  54. for t in self.transforms:
  55. t.pre_process()
  56. def __call__(self, img):
  57. for t in self.transforms:
  58. img = t(img)
  59. return img
  60. if __name__ == '__main__':
  61. import matplotlib.pyplot as plt
  62. from skimage import color
  63. import torch.nn.functional as F1
  64. path = '../ttt/ttt.png'
  65. img = Image.open(path).convert('RGB')
  66. train_transforms = list()
  67. train_transforms.append(Resize((200, 300)))
  68. train_transforms.append(Flip(p=0.5, flipCode=1))
  69. train_transforms.append(ToTensor())
  70. train_transforms.append(Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
  71. train_transforms = Compose(train_transforms)
  72. img1 = train_transforms(img)
  73. print(img1.min(), img1.max())