sam.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
  12. class Sam(nn.Module):
  13. mask_threshold: float = 0.0
  14. image_format: str = "RGB"
  15. def __init__(
  16. self,
  17. image_encoder: ImageEncoderViT,
  18. prompt_encoder: PromptEncoder,
  19. mask_decoder: MaskDecoder,
  20. pixel_mean: List[float] = [123.675, 116.28, 103.53],
  21. pixel_std: List[float] = [58.395, 57.12, 57.375],
  22. ) -> None:
  23. """
  24. SAM predicts object masks from an image and input prompts.
  25. Arguments:
  26. image_encoder (ImageEncoderViT): The backbone used to encode the
  27. image into image embeddings that allow for efficient mask prediction.
  28. prompt_encoder (PromptEncoder): Encodes various types of input prompts.
  29. mask_decoder (MaskDecoder): Predicts masks from the image embeddings
  30. and encoded prompts.
  31. pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
  32. pixel_std (list(float)): Std values for normalizing pixels in the input image.
  33. """
  34. super().__init__()
  35. self.image_encoder = image_encoder
  36. self.prompt_encoder = prompt_encoder
  37. self.mask_decoder = mask_decoder
  38. self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
  39. self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
  40. @property
  41. def device(self) -> Any:
  42. return self.pixel_mean.device
  43. @torch.no_grad()
  44. def forward(
  45. self,
  46. batched_input: List[Dict[str, Any]],
  47. multimask_output: bool,
  48. ) -> List[Dict[str, torch.Tensor]]:
  49. """
  50. Predicts masks end-to-end from provided images and prompts.
  51. If prompts are not known in advance, using SamPredictor is
  52. recommended over calling the model directly.
  53. Arguments:
  54. batched_input (list(dict)): A list over input images, each a
  55. dictionary with the following keys. A prompt key can be
  56. excluded if it is not present.
  57. 'image': The image as a torch tensor in 3xHxW format,
  58. already transformed for input to the model.
  59. 'original_size': (tuple(int, int)) The original size of
  60. the image before transformation, as (H, W).
  61. 'point_coords': (torch.Tensor) Batched point prompts for
  62. this image, with shape BxNx2. Already transformed to the
  63. input frame of the model.
  64. 'point_labels': (torch.Tensor) Batched labels for point prompts,
  65. with shape BxN.
  66. 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
  67. Already transformed to the input frame of the model.
  68. 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
  69. in the form Bx1xHxW.
  70. multimask_output (bool): Whether the model should predict multiple
  71. disambiguating masks, or return a single mask.
  72. Returns:
  73. (list(dict)): A list over input images, where each element is
  74. as dictionary with the following keys.
  75. 'masks': (torch.Tensor) Batched binary mask predictions,
  76. with shape BxCxHxW, where B is the number of input prompts,
  77. C is determined by multimask_output, and (H, W) is the
  78. original size of the image.
  79. 'iou_predictions': (torch.Tensor) The model's predictions
  80. of mask quality, in shape BxC.
  81. 'low_res_logits': (torch.Tensor) Low resolution logits with
  82. shape BxCxHxW, where H=W=256. Can be passed as mask input
  83. to subsequent iterations of prediction.
  84. """
  85. input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
  86. image_embeddings = self.image_encoder(input_images)
  87. outputs = []
  88. for image_record, curr_embedding in zip(batched_input, image_embeddings):
  89. if "point_coords" in image_record:
  90. points = (image_record["point_coords"], image_record["point_labels"])
  91. else:
  92. points = None
  93. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  94. points=points,
  95. boxes=image_record.get("boxes", None),
  96. masks=image_record.get("mask_inputs", None),
  97. )
  98. low_res_masks, iou_predictions = self.mask_decoder(
  99. image_embeddings=curr_embedding.unsqueeze(0),
  100. image_pe=self.prompt_encoder.get_dense_pe(),
  101. sparse_prompt_embeddings=sparse_embeddings,
  102. dense_prompt_embeddings=dense_embeddings,
  103. multimask_output=multimask_output,
  104. )
  105. masks = self.postprocess_masks(
  106. low_res_masks,
  107. input_size=image_record["image"].shape[-2:],
  108. original_size=image_record["original_size"],
  109. )
  110. masks = masks > self.mask_threshold
  111. outputs.append(
  112. {
  113. "masks": masks,
  114. "iou_predictions": iou_predictions,
  115. "low_res_logits": low_res_masks,
  116. }
  117. )
  118. return outputs
  119. def postprocess_masks(
  120. self,
  121. masks: torch.Tensor,
  122. input_size: Tuple[int, ...],
  123. original_size: Tuple[int, ...],
  124. ) -> torch.Tensor:
  125. """
  126. Remove padding and upscale masks to the original image size.
  127. Arguments:
  128. masks (torch.Tensor): Batched masks from the mask_decoder,
  129. in BxCxHxW format.
  130. input_size (tuple(int, int)): The size of the image input to the
  131. model, in (H, W) format. Used to remove padding.
  132. original_size (tuple(int, int)): The original size of the image
  133. before resizing for input to the model, in (H, W) format.
  134. Returns:
  135. (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
  136. is given by original_size.
  137. """
  138. masks = F.interpolate(
  139. masks,
  140. (self.image_encoder.img_size, self.image_encoder.img_size),
  141. mode="bilinear",
  142. align_corners=False,
  143. )
  144. masks = masks[..., : input_size[0], : input_size[1]]
  145. masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
  146. return masks
  147. def preprocess(self, x: torch.Tensor) -> torch.Tensor:
  148. """Normalize pixel values and pad to a square input."""
  149. # Normalize colors
  150. x = (x - self.pixel_mean) / self.pixel_std
  151. # Pad
  152. h, w = x.shape[-2:]
  153. padh = self.image_encoder.img_size - h
  154. padw = self.image_encoder.img_size - w
  155. x = F.pad(x, (0, padw, 0, padh))
  156. return x