transforms.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import numpy as np
  6. import torch
  7. from torch.nn import functional as F
  8. from torchvision.transforms.functional import resize, to_pil_image # type: ignore
  9. from copy import deepcopy
  10. from typing import Tuple
  11. class ResizeLongestSide:
  12. """
  13. Resizes images to the longest side 'target_length', as well as provides
  14. methods for resizing coordinates and boxes. Provides methods for
  15. transforming both numpy array and batched torch tensors.
  16. """
  17. def __init__(self, target_length: int) -> None:
  18. self.target_length = target_length
  19. def apply_image(self, image: np.ndarray) -> np.ndarray:
  20. """
  21. Expects a numpy array with shape HxWxC in uint8 format.
  22. """
  23. target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
  24. return np.array(resize(to_pil_image(image), target_size))
  25. def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
  26. """
  27. Expects a numpy array of length 2 in the final dimension. Requires the
  28. original image size in (H, W) format.
  29. """
  30. old_h, old_w = original_size
  31. new_h, new_w = self.get_preprocess_shape(
  32. original_size[0], original_size[1], self.target_length
  33. )
  34. coords = deepcopy(coords).astype(float)
  35. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  36. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  37. return coords
  38. def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
  39. """
  40. Expects a numpy array shape Bx4. Requires the original image size
  41. in (H, W) format.
  42. """
  43. boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
  44. return boxes.reshape(-1, 4)
  45. def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
  46. """
  47. Expects batched images with shape BxCxHxW and float format. This
  48. transformation may not exactly match apply_image. apply_image is
  49. the transformation expected by the model.
  50. """
  51. # Expects an image in BCHW format. May not exactly match apply_image.
  52. target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
  53. return F.interpolate(
  54. image, target_size, mode="bilinear", align_corners=False, antialias=True
  55. )
  56. def apply_coords_torch(
  57. self, coords: torch.Tensor, original_size: Tuple[int, ...]
  58. ) -> torch.Tensor:
  59. """
  60. Expects a torch tensor with length 2 in the last dimension. Requires the
  61. original image size in (H, W) format.
  62. """
  63. old_h, old_w = original_size
  64. new_h, new_w = self.get_preprocess_shape(
  65. original_size[0], original_size[1], self.target_length
  66. )
  67. coords = deepcopy(coords).to(torch.float)
  68. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  69. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  70. return coords
  71. def apply_boxes_torch(
  72. self, boxes: torch.Tensor, original_size: Tuple[int, ...]
  73. ) -> torch.Tensor:
  74. """
  75. Expects a torch tensor with shape Bx4. Requires the original image
  76. size in (H, W) format.
  77. """
  78. boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
  79. return boxes.reshape(-1, 4)
  80. @staticmethod
  81. def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
  82. """
  83. Compute the output size given input size and target long side length.
  84. """
  85. scale = long_side_length * 1.0 / max(oldh, oldw)
  86. newh, neww = oldh * scale, oldw * scale
  87. neww = int(neww + 0.5)
  88. newh = int(newh + 0.5)
  89. return (newh, neww)