data.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import cv2
  5. from PIL import Image
  6. import accimage
  7. import torchvision
  8. import torchvision.transforms as transforms
  9. from skimage import color
  10. def pil_loader(path):
  11. return Image.open(path).convert('RGB')
  12. def accimage_loader(path):
  13. return accimage.Image(path)
  14. def get_img_loader(loader_type):
  15. if loader_type == 'pil':
  16. return pil_loader
  17. elif loader_type == 'accimage':
  18. torchvision.set_image_backend('accimage')
  19. return accimage_loader
  20. else:
  21. raise ValueError('invalid image loader type: {}'.format(loader_type))
  22. # ---------- for visualization ----------
  23. def rgb_vis(img, mean, std):
  24. """
  25. Args:
  26. img : tensor, rgb[-1.0, 1.0], [3, H, W]
  27. Returns:
  28. img : numpy, rgb[0, 255]
  29. """
  30. img = img.data.cpu().numpy()
  31. for i in range(3):
  32. img[i, :, :] = img[i, :, :] * std[i] + mean[i]
  33. img = np.transpose(img, (1, 2, 0)) * 255
  34. img = np.clip(img, 0, 255)
  35. img = img.astype(np.uint8)
  36. return img
  37. def rgbs_vis(imgs, mean, std):
  38. """
  39. Args:
  40. img : tensor, rgb[-1.0, 1.0], [B, 3, H, W]
  41. Returns:
  42. img : tensor, rgb[0.0, 1.0]
  43. """
  44. bs = imgs.shape[0]
  45. imgs_tensor = []
  46. for i in range(bs):
  47. img = rgb_vis(imgs[i], mean, std)
  48. img = Image.fromarray(img)
  49. img = transforms.ToTensor()(img)
  50. imgs_tensor.append(img)
  51. imgs_tensor = torch.stack(imgs_tensor, dim=0)
  52. return imgs_tensor
  53. # ---------- for multi-scale training ----------
  54. def make_divisible(v, divisor=8, min_value=None):
  55. if min_value is None:
  56. min_value = divisor
  57. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  58. # Make sure that round down does not go down by more than 10%.
  59. if new_v < 0.9 * v:
  60. new_v += divisor
  61. return new_v
  62. def get_scales(n_scale, base_h, base_w, min_h, max_h, min_w, max_w, check_scale_div_factor=32):
  63. hs = list(np.linspace(min_h, max_h, n_scale))
  64. if base_h not in hs:
  65. hs.append(base_h)
  66. ws = list(np.linspace(min_w, max_w, n_scale))
  67. if base_w not in ws:
  68. ws.append(base_w)
  69. scales = set()
  70. for h, w in zip(hs, ws):
  71. h = make_divisible(h, check_scale_div_factor)
  72. w = make_divisible(w, check_scale_div_factor)
  73. scales.add((h, w))
  74. scales = list(scales)
  75. return scales