export_onnx_model.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import torch
  6. from segment_anything import sam_model_registry
  7. from segment_anything.utils.onnx import SamOnnxModel
  8. import argparse
  9. import warnings
  10. try:
  11. import onnxruntime # type: ignore
  12. onnxruntime_exists = True
  13. except ImportError:
  14. onnxruntime_exists = False
  15. parser = argparse.ArgumentParser(
  16. description="Export the SAM prompt encoder and mask decoder to an ONNX model."
  17. )
  18. parser.add_argument(
  19. "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
  20. )
  21. parser.add_argument(
  22. "--output", type=str, required=True, help="The filename to save the ONNX model to."
  23. )
  24. parser.add_argument(
  25. "--model-type",
  26. type=str,
  27. required=True,
  28. help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
  29. )
  30. parser.add_argument(
  31. "--return-single-mask",
  32. action="store_true",
  33. help=(
  34. "If true, the exported ONNX model will only return the best mask, "
  35. "instead of returning multiple masks. For high resolution images "
  36. "this can improve runtime when upscaling masks is expensive."
  37. ),
  38. )
  39. parser.add_argument(
  40. "--opset",
  41. type=int,
  42. default=17,
  43. help="The ONNX opset version to use. Must be >=11",
  44. )
  45. parser.add_argument(
  46. "--quantize-out",
  47. type=str,
  48. default=None,
  49. help=(
  50. "If set, will quantize the model and save it with this name. "
  51. "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
  52. ),
  53. )
  54. parser.add_argument(
  55. "--gelu-approximate",
  56. action="store_true",
  57. help=(
  58. "Replace GELU operations with approximations using tanh. Useful "
  59. "for some runtimes that have slow or unimplemented erf ops, used in GELU."
  60. ),
  61. )
  62. parser.add_argument(
  63. "--use-stability-score",
  64. action="store_true",
  65. help=(
  66. "Replaces the model's predicted mask quality score with the stability "
  67. "score calculated on the low resolution masks using an offset of 1.0. "
  68. ),
  69. )
  70. parser.add_argument(
  71. "--return-extra-metrics",
  72. action="store_true",
  73. help=(
  74. "The model will return five results: (masks, scores, stability_scores, "
  75. "areas, low_res_logits) instead of the usual three. This can be "
  76. "significantly slower for high resolution outputs."
  77. ),
  78. )
  79. def run_export(
  80. model_type: str,
  81. checkpoint: str,
  82. output: str,
  83. opset: int,
  84. return_single_mask: bool,
  85. gelu_approximate: bool = False,
  86. use_stability_score: bool = False,
  87. return_extra_metrics=False,
  88. ):
  89. print("Loading model...")
  90. sam = sam_model_registry[model_type](checkpoint=checkpoint)
  91. onnx_model = SamOnnxModel(
  92. model=sam,
  93. return_single_mask=return_single_mask,
  94. use_stability_score=use_stability_score,
  95. return_extra_metrics=return_extra_metrics,
  96. )
  97. if gelu_approximate:
  98. for n, m in onnx_model.named_modules():
  99. if isinstance(m, torch.nn.GELU):
  100. m.approximate = "tanh"
  101. dynamic_axes = {
  102. "point_coords": {1: "num_points"},
  103. "point_labels": {1: "num_points"},
  104. }
  105. embed_dim = sam.prompt_encoder.embed_dim
  106. embed_size = sam.prompt_encoder.image_embedding_size
  107. mask_input_size = [4 * x for x in embed_size]
  108. dummy_inputs = {
  109. "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
  110. "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
  111. "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
  112. "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
  113. "has_mask_input": torch.tensor([1], dtype=torch.float),
  114. "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
  115. }
  116. _ = onnx_model(**dummy_inputs)
  117. output_names = ["masks", "iou_predictions", "low_res_masks"]
  118. with warnings.catch_warnings():
  119. warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
  120. warnings.filterwarnings("ignore", category=UserWarning)
  121. with open(output, "wb") as f:
  122. print(f"Exporting onnx model to {output}...")
  123. torch.onnx.export(
  124. onnx_model,
  125. tuple(dummy_inputs.values()),
  126. f,
  127. export_params=True,
  128. verbose=False,
  129. opset_version=opset,
  130. do_constant_folding=True,
  131. input_names=list(dummy_inputs.keys()),
  132. output_names=output_names,
  133. dynamic_axes=dynamic_axes,
  134. )
  135. if onnxruntime_exists:
  136. ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
  137. # set cpu provider default
  138. providers = ["CPUExecutionProvider"]
  139. ort_session = onnxruntime.InferenceSession(output, providers=providers)
  140. _ = ort_session.run(None, ort_inputs)
  141. print("Model has successfully been run with ONNXRuntime.")
  142. def to_numpy(tensor):
  143. return tensor.cpu().numpy()
  144. if __name__ == "__main__":
  145. args = parser.parse_args()
  146. run_export(
  147. model_type=args.model_type,
  148. checkpoint=args.checkpoint,
  149. output=args.output,
  150. opset=args.opset,
  151. return_single_mask=args.return_single_mask,
  152. gelu_approximate=args.gelu_approximate,
  153. use_stability_score=args.use_stability_score,
  154. return_extra_metrics=args.return_extra_metrics,
  155. )
  156. if args.quantize_out is not None:
  157. assert onnxruntime_exists, "onnxruntime is required to quantize the model."
  158. from onnxruntime.quantization import QuantType # type: ignore
  159. from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
  160. print(f"Quantizing model and writing to {args.quantize_out}...")
  161. quantize_dynamic(
  162. model_input=args.output,
  163. model_output=args.quantize_out,
  164. optimize_model=True,
  165. per_channel=False,
  166. reduce_range=False,
  167. weight_type=QuantType.QUInt8,
  168. )
  169. print("Done!")