prompt_encoder.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import numpy as np
  6. import torch
  7. from torch import nn
  8. from typing import Any, Optional, Tuple, Type
  9. from .common import LayerNorm2d
  class PromptEncoder(nn.Module):
      """
      Encodes point, box and mask prompts for input to SAM's mask decoder.

      Points and box corners are turned into "sparse" embeddings (one vector
      per point / corner). Masks are turned into a "dense" embedding map; when
      no mask is given, a single learned vector is broadcast over the
      image-embedding grid instead.
      """

      def __init__(
          self,
          embed_dim: int,
          image_embedding_size: Tuple[int, int],
          input_image_size: Tuple[int, int],
          mask_in_chans: int,
          activation: Type[nn.Module] = nn.GELU,
      ) -> None:
          """
          Builds the learned embeddings and the mask down-scaling network.

          Arguments:
            embed_dim (int): The prompts' embedding dimension.
            image_embedding_size (tuple(int, int)): The spatial size of the
              image embedding, as (H, W).
            input_image_size (tuple(int, int)): The padded size of the image
              as input to the image encoder, as (H, W).
            mask_in_chans (int): The number of hidden channels used for
              encoding input masks.
            activation (nn.Module): The activation class instantiated between
              the mask-encoding convolutions.
          """
          super().__init__()
          self.embed_dim = embed_dim
          self.input_image_size = input_image_size
          self.image_embedding_size = image_embedding_size
          self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

          # Four learned vectors (pos/neg point + 2 box corners): indices 0/1
          # are added to points labelled 0/1, indices 2/3 to the two corners
          # of a box.
          self.num_point_embeddings: int = 4
          self.point_embeddings = nn.ModuleList(
              nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)
          )
          self.not_a_point_embed = nn.Embedding(1, embed_dim)

          # The two stride-2 convolutions below shrink a mask by 4x per side,
          # so the mask input size is 4x the image-embedding grid.
          self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
          reduced_chans = mask_in_chans // 4
          self.mask_downscaling = nn.Sequential(
              nn.Conv2d(1, reduced_chans, kernel_size=2, stride=2),
              LayerNorm2d(reduced_chans),
              activation(),
              nn.Conv2d(reduced_chans, mask_in_chans, kernel_size=2, stride=2),
              LayerNorm2d(mask_in_chans),
              activation(),
              nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
          )
          self.no_mask_embed = nn.Embedding(1, embed_dim)

      def get_dense_pe(self) -> torch.Tensor:
          """
          Returns the positional encoding used to encode point prompts,
          evaluated on a dense grid with the shape of the image encoding.

          Returns:
            torch.Tensor: Positional encoding with shape
              1x(embed_dim)x(embedding_h)x(embedding_w)
          """
          grid_encoding = self.pe_layer(self.image_embedding_size)
          return grid_encoding.unsqueeze(0)

      def _embed_points(
          self,
          points: torch.Tensor,
          labels: torch.Tensor,
          pad: bool,
      ) -> torch.Tensor:
          """
          Embeds point prompts.

          When ``pad`` is true, one extra point at the origin with label -1 is
          appended to every batch item. Entries labelled -1 are replaced by the
          learned "not a point" vector; entries labelled 0 or 1 get the
          matching learned point embedding added to their positional encoding.
          """
          points = points + 0.5  # Shift to center of pixel
          if pad:
              extra_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
              extra_label = torch.full((labels.shape[0], 1), -1.0, device=labels.device)
              points = torch.cat((points, extra_point), dim=1)
              labels = torch.cat((labels, extra_label), dim=1)

          embedded = self.pe_layer.forward_with_coords(points, self.input_image_size)

          # Padding entries: wipe the positional encoding, then write the
          # dedicated "not a point" vector in its place.
          is_padding = labels == -1
          embedded[is_padding] = 0.0
          embedded[is_padding] += self.not_a_point_embed.weight
          embedded[labels == 0] += self.point_embeddings[0].weight
          embedded[labels == 1] += self.point_embeddings[1].weight
          return embedded

      def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
          """Embeds box prompts as two corner points per box."""
          # Shift to center of pixel, then view each box as two (x, y) corners.
          corners = (boxes + 0.5).reshape(-1, 2, 2)
          embedded = self.pe_layer.forward_with_coords(corners, self.input_image_size)
          first_corner, second_corner = self.point_embeddings[2], self.point_embeddings[3]
          embedded[:, 0, :] += first_corner.weight
          embedded[:, 1, :] += second_corner.weight
          return embedded

      def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
          """Embeds mask inputs by running them through the down-scaling network."""
          return self.mask_downscaling(masks)

      def _get_batch_size(
          self,
          points: Optional[Tuple[torch.Tensor, torch.Tensor]],
          boxes: Optional[torch.Tensor],
          masks: Optional[torch.Tensor],
      ) -> int:
          """
          Gets the batch size of the output given the batch size of the input
          prompts. The first prompt that is present wins, checked in the order
          points, boxes, masks; with no prompt at all the batch size is 1.
          """
          if points is not None:
              return points[0].shape[0]
          if boxes is not None:
              return boxes.shape[0]
          if masks is not None:
              return masks.shape[0]
          return 1

      def _get_device(self) -> torch.device:
          """Returns the device holding the first point embedding's weight."""
          return self.point_embeddings[0].weight.device

      def forward(
          self,
          points: Optional[Tuple[torch.Tensor, torch.Tensor]],
          boxes: Optional[torch.Tensor],
          masks: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Embeds different types of prompts, returning both sparse and dense
          embeddings.

          Arguments:
            points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
              and labels to embed.
            boxes (torch.Tensor or none): boxes to embed
            masks (torch.Tensor or none): masks to embed

          Returns:
            torch.Tensor: sparse embeddings for the points and boxes, with shape
              BxNx(embed_dim), where N is determined by the number of input points
              and boxes.
            torch.Tensor: dense embeddings for the masks, in the shape
              Bx(embed_dim)x(embed_H)x(embed_W)
          """
          batch = self._get_batch_size(points, boxes, masks)

          # Start from a zero-length sequence so the result is well-formed even
          # when neither points nor boxes are supplied.
          sparse_parts = [torch.empty((batch, 0, self.embed_dim), device=self._get_device())]
          if points is not None:
              coords, labels = points
              # Points are padded only when no box accompanies them.
              sparse_parts.append(self._embed_points(coords, labels, pad=(boxes is None)))
          if boxes is not None:
              sparse_parts.append(self._embed_boxes(boxes))
          sparse_embeddings = torch.cat(sparse_parts, dim=1)

          if masks is None:
              # No mask prompt: broadcast the learned placeholder over the grid.
              placeholder = self.no_mask_embed.weight.reshape(1, -1, 1, 1)
              dense_embeddings = placeholder.expand(
                  batch, -1, self.image_embedding_size[0], self.image_embedding_size[1]
              )
          else:
              dense_embeddings = self._embed_masks(masks)

          return sparse_embeddings, dense_embeddings
  class PositionEmbeddingRandom(nn.Module):
      """
      Positional encoding using random spatial frequencies.

      A fixed random 2 x num_pos_feats projection matrix is drawn once at
      construction time and registered as a buffer (so it is not a trainable
      parameter). A 2-D coordinate is projected through it and the sine and
      cosine of the result are concatenated, giving 2 * num_pos_feats output
      channels.
      """

      def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
          """
          Arguments:
            num_pos_feats (int): Number of random frequencies; the encoding has
              twice this many channels.
            scale (float or none): Multiplier applied to the random matrix.
              ``None`` or a non-positive value falls back to 1.0.
          """
          super().__init__()
          if scale is None or scale <= 0.0:
              scale = 1.0
          self.register_buffer(
              "positional_encoding_gaussian_matrix",
              scale * torch.randn((2, num_pos_feats)),
          )

      def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
          """Positionally encode points that are normalized to [0,1]."""
          # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
          coords = 2 * coords - 1  # remap [0, 1] -> [-1, 1]
          coords = coords @ self.positional_encoding_gaussian_matrix
          coords = 2 * np.pi * coords
          # outputs d_1 x ... x d_n x C shape
          return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

      def forward(self, size: Tuple[int, int]) -> torch.Tensor:
          """
          Generate positional encoding for a grid of the specified size.

          Arguments:
            size (tuple(int, int)): Grid size as (H, W).

          Returns:
            torch.Tensor: Encoding of every cell centre, with shape C x H x W.
          """
          h, w = size
          device: Any = self.positional_encoding_gaussian_matrix.device
          grid = torch.ones((h, w), device=device, dtype=torch.float32)
          # Cumulative sums give 1..h / 1..w; subtracting 0.5 moves each value
          # to the cell centre before normalizing to [0, 1].
          y_embed = grid.cumsum(dim=0) - 0.5
          x_embed = grid.cumsum(dim=1) - 0.5
          y_embed = y_embed / h
          x_embed = x_embed / w

          pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
          return pe.permute(2, 0, 1)  # C x H x W

      def forward_with_coords(
          self, coords_input: torch.Tensor, image_size: Tuple[int, int]
      ) -> torch.Tensor:
          """
          Positionally encode points that are not normalized to [0,1].

          Arguments:
            coords_input (torch.Tensor): B x N x 2 tensor of (x, y) coordinates
              in the same units as ``image_size``. It is never modified.
            image_size (tuple(int, int)): Image size as (H, W); x is divided by
              W and y by H.

          Returns:
            torch.Tensor: float32 encoding with shape B x N x C.
          """
          # Work on a private floating-point copy. An integer tensor must be
          # converted *before* normalizing: writing the quotients back into an
          # integer tensor would truncate every normalized coordinate to 0 or 1.
          if coords_input.is_floating_point():
              coords = coords_input.clone()
          else:
              coords = coords_input.to(torch.float)
          coords[:, :, 0] = coords[:, :, 0] / image_size[1]
          coords[:, :, 1] = coords[:, :, 1] / image_size[0]
          return self._pe_encoding(coords.to(torch.float))  # B x N x C