predictor.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import numpy as np
  6. import torch
  7. from segment_anything.modeling import Sam
  8. from typing import Optional, Tuple
  9. from .utils.transforms import ResizeLongestSide
  class SamPredictor:
      """
      Wraps a SAM model so that the (expensive) image embedding is computed
      once via `set_image` / `set_torch_image` and then reused across repeated
      `predict` / `predict_torch` calls with different prompts.
      """

      def __init__(
          self,
          sam_model: "Sam",
      ) -> None:
          """
          Uses SAM to calculate the image embedding for an image, and then
          allow repeated, efficient mask prediction given prompts.

          Arguments:
            sam_model (Sam): The model to use for mask prediction.
          """
          super().__init__()
          self.model = sam_model
          # Used to resize the input image and to map point/box prompts from
          # original-image pixel coordinates into the model's input frame.
          self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
          self.reset_image()

      def set_image(
          self,
          image: np.ndarray,
          image_format: str = "RGB",
      ) -> None:
          """
          Calculates the image embeddings for the provided image, allowing
          masks to be predicted with the 'predict' method.

          Arguments:
            image (np.ndarray): The image for calculating masks. Expects an
              image in HWC uint8 format, with pixel values in [0, 255].
            image_format (str): The color format of the image, in ['RGB', 'BGR'].
          """
          assert image_format in [
              "RGB",
              "BGR",
          ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
          # Reverse the channel axis when the caller's color order differs
          # from the one the model was configured with.
          if image_format != self.model.image_format:
              image = image[..., ::-1]

          # Transform the image to the form expected by the model
          input_image = self.transform.apply_image(image)
          input_image_torch = torch.as_tensor(input_image, device=self.device)
          # HWC -> CHW, then prepend a batch dimension: 1xCxHxW.
          input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

          self.set_torch_image(input_image_torch, image.shape[:2])

      @torch.no_grad()
      def set_torch_image(
          self,
          transformed_image: torch.Tensor,
          original_image_size: Tuple[int, ...],
      ) -> None:
          """
          Calculates the image embeddings for the provided image, allowing
          masks to be predicted with the 'predict' method. Expects the input
          image to be already transformed to the format expected by the model.

          Arguments:
            transformed_image (torch.Tensor): The input image, with shape
              1x3xHxW, which has been transformed with ResizeLongestSide.
            original_image_size (tuple(int, int)): The size of the image
              before transformation, in (H, W) format.
          """
          assert (
              len(transformed_image.shape) == 4
              and transformed_image.shape[1] == 3
              and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
          ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
          # Drop any previous image's features/sizes before computing new ones.
          self.reset_image()

          # Both sizes are needed later to upscale low-res masks back to the
          # original image resolution (see predict_torch).
          self.original_size = original_image_size
          self.input_size = tuple(transformed_image.shape[-2:])
          input_image = self.model.preprocess(transformed_image)
          self.features = self.model.image_encoder(input_image)
          self.is_image_set = True

      def predict(
          self,
          point_coords: Optional[np.ndarray] = None,
          point_labels: Optional[np.ndarray] = None,
          box: Optional[np.ndarray] = None,
          mask_input: Optional[np.ndarray] = None,
          multimask_output: bool = True,
          return_logits: bool = False,
      ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
          """
          Predict masks for the given input prompts, using the currently set image.

          Arguments:
            point_coords (np.ndarray or None): A Nx2 array of point prompts to the
              model. Each point is in (X,Y) in pixels.
            point_labels (np.ndarray or None): A length N array of labels for the
              point prompts. 1 indicates a foreground point and 0 indicates a
              background point. Required whenever point_coords is given.
            box (np.ndarray or None): A length 4 array given a box prompt to the
              model, in XYXY format.
            mask_input (np.ndarray or None): A low resolution mask input to the
              model, typically coming from a previous prediction iteration. Has
              form 1xHxW, where for SAM, H=W=256.
            multimask_output (bool): If true, the model will return three masks.
              For ambiguous input prompts (such as a single click), this will often
              produce better masks than a single prediction. If only a single
              mask is needed, the model's predicted quality score can be used
              to select the best mask. For non-ambiguous prompts, such as multiple
              input prompts, multimask_output=False can give better results.
            return_logits (bool): If true, returns un-thresholded masks logits
              instead of a binary mask.

          Returns:
            (np.ndarray): The output masks in CxHxW format, where C is the
              number of masks, and (H, W) is the original image size.
            (np.ndarray): An array of length C containing the model's
              predictions for the quality of each mask.
            (np.ndarray): An array of shape CxHxW, where C is the number
              of masks and H=W=256. These low resolution logits can be passed to
              a subsequent iteration as mask input.

          Raises:
            RuntimeError: If no image has been set.
          """
          if not self.is_image_set:
              raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

          # Transform input prompts into batched tensors (batch size 1), since
          # predict_torch expects a leading batch dimension on every prompt.
          coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
          if point_coords is not None:
              assert (
                  point_labels is not None
              ), "point_labels must be supplied if point_coords is supplied."
              point_coords = self.transform.apply_coords(point_coords, self.original_size)
              coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
              labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
              coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
          if box is not None:
              box = self.transform.apply_boxes(box, self.original_size)
              box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
              box_torch = box_torch[None, :]
          if mask_input is not None:
              mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
              mask_input_torch = mask_input_torch[None, :, :, :]

          masks, iou_predictions, low_res_masks = self.predict_torch(
              coords_torch,
              labels_torch,
              box_torch,
              mask_input_torch,
              multimask_output,
              return_logits=return_logits,
          )

          # Strip the batch dimension and move results back to NumPy.
          masks_np = masks[0].detach().cpu().numpy()
          iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
          low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
          return masks_np, iou_predictions_np, low_res_masks_np

      @torch.no_grad()
      def predict_torch(
          self,
          point_coords: Optional[torch.Tensor],
          point_labels: Optional[torch.Tensor],
          boxes: Optional[torch.Tensor] = None,
          mask_input: Optional[torch.Tensor] = None,
          multimask_output: bool = True,
          return_logits: bool = False,
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
          Predict masks for the given input prompts, using the currently set image.
          Input prompts are batched torch tensors and are expected to already be
          transformed to the input frame using ResizeLongestSide.

          Arguments:
            point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
              model. Each point is in (X,Y) in pixels.
            point_labels (torch.Tensor or None): A BxN array of labels for the
              point prompts. 1 indicates a foreground point and 0 indicates a
              background point.
            boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
              model, in XYXY format.
            mask_input (torch.Tensor or None): A low resolution mask input to the
              model, typically coming from a previous prediction iteration. Has
              form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
              iteration of the predict method do not need further transformation.
            multimask_output (bool): If true, the model will return three masks.
              For ambiguous input prompts (such as a single click), this will often
              produce better masks than a single prediction. If only a single
              mask is needed, the model's predicted quality score can be used
              to select the best mask. For non-ambiguous prompts, such as multiple
              input prompts, multimask_output=False can give better results.
            return_logits (bool): If true, returns un-thresholded masks logits
              instead of a binary mask.

          Returns:
            (torch.Tensor): The output masks in BxCxHxW format, where C is the
              number of masks, and (H, W) is the original image size.
            (torch.Tensor): An array of shape BxC containing the model's
              predictions for the quality of each mask.
            (torch.Tensor): An array of shape BxCxHxW, where C is the number
              of masks and H=W=256. These low res logits can be passed to
              a subsequent iteration as mask input.

          Raises:
            RuntimeError: If no image has been set.
          """
          if not self.is_image_set:
              raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

          # The prompt encoder takes points as a (coords, labels) pair, or None.
          if point_coords is not None:
              points = (point_coords, point_labels)
          else:
              points = None

          # Embed prompts
          sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
              points=points,
              boxes=boxes,
              masks=mask_input,
          )

          # Predict masks
          low_res_masks, iou_predictions = self.model.mask_decoder(
              image_embeddings=self.features,
              image_pe=self.model.prompt_encoder.get_dense_pe(),
              sparse_prompt_embeddings=sparse_embeddings,
              dense_prompt_embeddings=dense_embeddings,
              multimask_output=multimask_output,
          )

          # Upscale the masks to the original image resolution, using the sizes
          # recorded by set_torch_image.
          masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

          if not return_logits:
              masks = masks > self.model.mask_threshold

          return masks, iou_predictions, low_res_masks

      def get_image_embedding(self) -> torch.Tensor:
          """
          Returns the image embeddings for the currently set image, with
          shape 1xCxHxW, where C is the embedding dimension and (H,W) are
          the embedding spatial dimension of SAM (typically C=256, H=W=64).

          Raises:
            RuntimeError: If no image has been set.
          """
          if not self.is_image_set:
              raise RuntimeError(
                  "An image must be set with .set_image(...) to generate an embedding."
              )
          assert self.features is not None, "Features must exist if an image has been set."
          return self.features

      @property
      def device(self) -> torch.device:
          """The device of the wrapped model; prompt tensors are created on it."""
          return self.model.device

      def reset_image(self) -> None:
          """Resets the currently set image and every piece of state derived from it."""
          self.is_image_set = False
          self.features = None
          # Sizes written by set_torch_image and read by predict/predict_torch.
          # Clearing them prevents a previous image's sizes from leaking past
          # a reset (and guarantees the attributes exist right after __init__).
          self.original_size = None
          self.input_size = None
          # Legacy attributes: nothing in this class populates or reads them;
          # they are kept only so external code that reads them keeps working.
          self.orig_h = None
          self.orig_w = None
          self.input_h = None
          self.input_w = None