| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import torch
- from segment_anything import sam_model_registry
- from segment_anything.utils.onnx import SamOnnxModel
- import argparse
- import warnings
- try:
- import onnxruntime # type: ignore
- onnxruntime_exists = True
- except ImportError:
- onnxruntime_exists = False
def _opset_version(value: str) -> int:
    """argparse ``type`` for ``--opset``: parse an int and enforce the >= 11 floor."""
    try:
        opset = int(value)
    except ValueError:
        # Keep the same wording argparse uses for a plain ``type=int`` failure.
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if opset < 11:
        raise argparse.ArgumentTypeError(f"opset must be >= 11, got {opset}")
    return opset


# Command-line interface for the export script. The arguments map one-to-one
# onto the parameters of run_export, plus --quantize-out for the optional
# onnxruntime quantization step.
parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

parser.add_argument(
    "--return-single-mask",
    action="store_true",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
)

parser.add_argument(
    "--opset",
    # Validated so an unsupported opset fails fast with a usage error
    # instead of failing later inside the exporter.
    type=_opset_version,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)

parser.add_argument(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    ),
)

parser.add_argument(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics: bool = False,
):
    """Export the SAM prompt encoder and mask decoder to an ONNX file.

    Builds the SAM model, wraps it in ``SamOnnxModel``, traces it with dummy
    inputs and writes the ONNX graph to ``output``. If onnxruntime is
    available, the exported file is then run once on the CPU as a smoke test.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the SAM variant.
        checkpoint: Path of the checkpoint the model weights are loaded from.
        output: Filename the ONNX model is written to.
        opset: ONNX opset version passed to ``torch.onnx.export``; must be >= 11.
        return_single_mask: Forwarded to ``SamOnnxModel``.
        gelu_approximate: If True, every ``torch.nn.GELU`` in the model is
            switched to the tanh approximation before export.
        use_stability_score: Forwarded to ``SamOnnxModel``.
        return_extra_metrics: Forwarded to ``SamOnnxModel``. When True the
            model yields five outputs instead of three, and the ONNX outputs
            are named accordingly.

    Raises:
        ValueError: If ``opset`` is below 11.
    """
    if opset < 11:
        raise ValueError(f"ONNX opset must be >= 11, got {opset}.")

    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        # Exact GELU relies on erf, which some runtimes implement slowly or not
        # at all; the tanh approximation avoids that op.
        for module in onnx_model.modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Only the number of prompt points is dynamic; every other input is
    # exported with the fixed shape of its dummy tensor below.
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    # The dummy mask prompt is built at 4x the image-embedding resolution.
    mask_input_size = [4 * x for x in embed_size]

    # Tracing only needs correctly shaped/typed inputs; the values are arbitrary.
    # Insertion order matters: the values are passed positionally to the model
    # and the keys become the ONNX input names.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Run the model once eagerly on the dummy inputs before exporting, as a
    # sanity check that it executes.
    _ = onnx_model(**dummy_inputs)

    # Output names are matched positionally to the model's returned tuple.
    # With extra metrics enabled it returns
    # (masks, scores, stability_scores, areas, low_res_logits), so a 3-name
    # list would mislabel the stability scores as "low_res_masks".
    if return_extra_metrics:
        output_names = [
            "masks",
            "iou_predictions",
            "stability_scores",
            "areas",
            "low_res_masks",
        ]
    else:
        output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        # Tracing emits many TracerWarnings/UserWarnings that are noise here.
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        # Smoke-test the exported file with onnxruntime on the same inputs.
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # set cpu provider default
        providers = ["CPUExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(output, providers=providers)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    The tensor is detached from the autograd graph first, so tensors with
    ``requires_grad=True`` can be converted too; ``Tensor.numpy()`` refuses
    those. For a tensor already on the CPU, the returned array shares memory
    with it.
    """
    return tensor.detach().cpu().numpy()
if __name__ == "__main__":
    args = parser.parse_args()

    # Check the quantization prerequisite before the (slow) export, so that a
    # missing onnxruntime is reported immediately. argparse is used here rather
    # than an assert, which ``python -O`` would strip.
    if args.quantize_out is not None and not onnxruntime_exists:
        parser.error("onnxruntime is required to quantize the model.")

    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        # Imported lazily: only needed when quantization was requested.
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
|