onnx.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import torch
  6. import torch.nn as nn
  7. from torch.nn import functional as F
  8. from typing import Tuple
  9. from ..modeling import Sam
  10. from .amg import calculate_stability_score
  11. class SamOnnxModel(nn.Module):
  12. """
  13. This model should not be called directly, but is used in ONNX export.
  14. It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
  15. with some functions modified to enable model tracing. Also supports extra
  16. options controlling what information. See the ONNX export script for details.
  17. """
  18. def __init__(
  19. self,
  20. model: Sam,
  21. return_single_mask: bool,
  22. use_stability_score: bool = False,
  23. return_extra_metrics: bool = False,
  24. ) -> None:
  25. super().__init__()
  26. self.mask_decoder = model.mask_decoder
  27. self.model = model
  28. self.img_size = model.image_encoder.img_size
  29. self.return_single_mask = return_single_mask
  30. self.use_stability_score = use_stability_score
  31. self.stability_score_offset = 1.0
  32. self.return_extra_metrics = return_extra_metrics
  33. @staticmethod
  34. def resize_longest_image_size(
  35. input_image_size: torch.Tensor, longest_side: int
  36. ) -> torch.Tensor:
  37. input_image_size = input_image_size.to(torch.float32)
  38. scale = longest_side / torch.max(input_image_size)
  39. transformed_size = scale * input_image_size
  40. transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
  41. return transformed_size
  42. def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
  43. point_coords = point_coords + 0.5
  44. point_coords = point_coords / self.img_size
  45. point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
  46. point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
  47. point_embedding = point_embedding * (point_labels != -1)
  48. point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
  49. point_labels == -1
  50. )
  51. for i in range(self.model.prompt_encoder.num_point_embeddings):
  52. point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
  53. i
  54. ].weight * (point_labels == i)
  55. return point_embedding
  56. def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
  57. mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
  58. mask_embedding = mask_embedding + (
  59. 1 - has_mask_input
  60. ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
  61. return mask_embedding
  62. def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
  63. masks = F.interpolate(
  64. masks,
  65. size=(self.img_size, self.img_size),
  66. mode="bilinear",
  67. align_corners=False,
  68. )
  69. prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
  70. masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
  71. orig_im_size = orig_im_size.to(torch.int64)
  72. h, w = orig_im_size[0], orig_im_size[1]
  73. masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
  74. return masks
  75. def select_masks(
  76. self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
  77. ) -> Tuple[torch.Tensor, torch.Tensor]:
  78. # Determine if we should return the multiclick mask or not from the number of points.
  79. # The reweighting is used to avoid control flow.
  80. score_reweight = torch.tensor(
  81. [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
  82. ).to(iou_preds.device)
  83. score = iou_preds + (num_points - 2.5) * score_reweight
  84. best_idx = torch.argmax(score, dim=1)
  85. masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
  86. iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
  87. return masks, iou_preds
  88. @torch.no_grad()
  89. def forward(
  90. self,
  91. image_embeddings: torch.Tensor,
  92. point_coords: torch.Tensor,
  93. point_labels: torch.Tensor,
  94. mask_input: torch.Tensor,
  95. has_mask_input: torch.Tensor,
  96. orig_im_size: torch.Tensor,
  97. ):
  98. sparse_embedding = self._embed_points(point_coords, point_labels)
  99. dense_embedding = self._embed_masks(mask_input, has_mask_input)
  100. masks, scores = self.model.mask_decoder.predict_masks(
  101. image_embeddings=image_embeddings,
  102. image_pe=self.model.prompt_encoder.get_dense_pe(),
  103. sparse_prompt_embeddings=sparse_embedding,
  104. dense_prompt_embeddings=dense_embedding,
  105. )
  106. if self.use_stability_score:
  107. scores = calculate_stability_score(
  108. masks, self.model.mask_threshold, self.stability_score_offset
  109. )
  110. if self.return_single_mask:
  111. masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
  112. upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
  113. if self.return_extra_metrics:
  114. stability_scores = calculate_stability_score(
  115. upscaled_masks, self.model.mask_threshold, self.stability_score_offset
  116. )
  117. areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
  118. return upscaled_masks, scores, stability_scores, areas, masks
  119. return upscaled_masks, scores, masks