Commit 0d4526f

onnx
1 parent 759948f commit 0d4526f

3 files changed: +259 −19 lines changed


README.md

Lines changed: 11 additions & 1 deletion
@@ -8,7 +8,9 @@ We propose HQ-SAM to upgrade SAM for high-quality zero-shot segmentation. Refer
 
 Updates
 -----------------
-:fire::fire: Play with HQ-SAM demo at [![Huggingfaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sam-hq-team/sam-hq), supported by [Huggingface Spaces](https://huggingface.co/spaces), which supports point, box and text prompts.
+:fire::fire: We release the [ONNX export script](#onnx-export) and [colab notebook](https://colab.research.google.com/drive/11U2La49c2IxahzJkAV-EzPqEH3cz_5hq?usp=sharing) for exporting and using the ONNX model.
+
+:fire: Play with HQ-SAM demo at [![Huggingfaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sam-hq-team/sam-hq), supported by [Huggingface Spaces](https://huggingface.co/spaces), which supports point, box and text prompts.
 
 :fire: We released the [colab notebook demo](https://colab.research.google.com/drive/1QwAbn5hsdqKOD5niuBzuqQX4eLCbNKFL?usp=sharing) and [automatic mask generator notebook](https://colab.research.google.com/drive/1dhRq4eR6Fbl-yl1vbQvU9hqyyeOidQaU?usp=sharing).
 
@@ -133,6 +135,14 @@ python demo/demo_sam.py
 ```
 
 
+### **ONNX export**
+HQ-SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports the ONNX runtime. Export the model with
+```
+python scripts/export_onnx_model.py --checkpoint <path/to/checkpoint> --model-type <model_type> --output <path/to/output>
+```
+See the [example notebook](https://colab.research.google.com/drive/11U2La49c2IxahzJkAV-EzPqEH3cz_5hq?usp=sharing) for details on how to combine image preprocessing via HQ-SAM's backbone with mask prediction using the ONNX model. It is recommended to use the latest stable version of PyTorch for ONNX export.
+
+
 Citation
 ---------------
 If you find HQ-SAM useful in your research or refer to the provided baseline results, please star :star: this repository and consider citing :pencil::
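For orientation, the exported decoder is an ordinary ONNX graph. The sketch below is a minimal illustration (not code from this commit) of running it with onnxruntime, using random tensors shaped like the export script's dummy inputs for a vit_l model. The filename `sam_hq_vit_l.onnx`, the 256-dim prompt-encoder embedding, and the 64×64 embedding grid are assumptions; in real use the embeddings come from HQ-SAM's image encoder, and click coordinates must first be mapped into the 1024-pixel resized frame (see the coordinate sketch after the export script).

```python
import numpy as np
import onnxruntime

# Assumed filename; use whatever --output was passed to the export script.
session = onnxruntime.InferenceSession("sam_hq_vit_l.onnx", providers=["CPUExecutionProvider"])

ort_inputs = {
    # (1, 256, 64, 64): image embedding from HQ-SAM's backbone (random here).
    "image_embeddings": np.random.randn(1, 256, 64, 64).astype(np.float32),
    # (4, 1, 64, 64, 1024): early ViT features; the last dim is 1024 for vit_l.
    "interm_embeddings": np.random.randn(4, 1, 64, 64, 1024).astype(np.float32),
    # One foreground click plus one padding point (label -1), coords in the 1024-resized frame.
    "point_coords": np.array([[[500.0, 375.0], [0.0, 0.0]]], dtype=np.float32),
    "point_labels": np.array([[1.0, -1.0]], dtype=np.float32),
    # No previous low-res mask is provided.
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    # (H, W) of the original image; masks are upscaled to this size.
    "orig_im_size": np.array([1500.0, 2250.0], dtype=np.float32),
}

# Exported without --return-extra-metrics, the model returns three outputs.
masks, iou_predictions, low_res_masks = session.run(None, ort_inputs)
print(masks.shape)  # (1, 1, 1500, 2250) mask logits; threshold at 0 for a binary mask
```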

scripts/export_onnx_model.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

import argparse
import warnings

try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

parser.add_argument(
    "--hq-token-only",
    action="store_true",
    help=(
        "False means use the hq output to correct the SAM output; True means use the hq output only. Default: False. "
        "For the best visualization effect, we suggest setting hq_token_only=False for images containing multiple "
        "objects (like typical COCO images) and hq_token_only=True for images containing a single object. "
        "For quantitative evaluation on COCO/YTVOS/DAVIS/UVO/LVIS etc., we set hq_token_only=False."
    ),
)


parser.add_argument(
    "--multimask-output",
    action="store_true",
    help=(
        "If true, the exported ONNX model will use multi-mask output mode and "
        "select the best mask among the multi-mask outputs."
    ),
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)

parser.add_argument(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0."
    ),
)

parser.add_argument(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    hq_token_only: bool = False,
    multimask_output: bool = False,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics=False,
):
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        hq_token_only=hq_token_only,
        multimask_output=multimask_output,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    if gelu_approximate:
        for n, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    encoder_embed_dim_dict = {"vit_b": 768, "vit_l": 1024, "vit_h": 1280}
    encoder_embed_dim = encoder_embed_dim_dict[model_type]

    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "interm_embeddings": torch.randn(4, 1, *embed_size, encoder_embed_dim, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    _ = onnx_model(**dummy_inputs)

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        # use the CPU execution provider by default
        providers = ["CPUExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(output, providers=providers)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        hq_token_only=args.hq_token_only,
        multimask_output=args.multimask_output,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
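Note that the script's dummy `point_coords` are sampled directly in `[0, 1024)`, i.e. in the coordinate frame of the resized model input rather than the original image. Below is a minimal sketch (not part of this commit) of mapping a real click into that frame before feeding it to the exported model, assuming this fork keeps SAM's `ResizeLongestSide` utility unchanged; the click location and image size are placeholder values.

```python
import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)   # SAM resizes the longest image side to 1024 pixels
original_size = (1500, 2250)          # (H, W) of the original image, matching orig_im_size above

# One foreground click in original-image pixels plus one padding point (label -1),
# following the convention of the original SAM ONNX example.
point_coords = np.array([[[900.0, 600.0], [0.0, 0.0]]], dtype=np.float32)
point_labels = np.array([[1.0, -1.0]], dtype=np.float32)

# Rescale the coordinates into the 1024-resized frame expected by the exported model.
onnx_coords = transform.apply_coords(point_coords, original_size).astype(np.float32)
```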

segment_anything/utils/onnx.py

Lines changed: 29 additions & 18 deletions
@@ -25,15 +25,17 @@ class SamOnnxModel(nn.Module):
     def __init__(
         self,
         model: Sam,
-        return_single_mask: bool,
+        hq_token_only: bool = False,
+        multimask_output: bool = False,
         use_stability_score: bool = False,
         return_extra_metrics: bool = False,
     ) -> None:
         super().__init__()
         self.mask_decoder = model.mask_decoder
         self.model = model
         self.img_size = model.image_encoder.img_size
-        self.return_single_mask = return_single_mask
+        self.hq_token_only = hq_token_only
+        self.multimask_output = multimask_output
         self.use_stability_score = use_stability_score
         self.stability_score_offset = 1.0
         self.return_extra_metrics = return_extra_metrics
@@ -89,25 +91,12 @@ def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -
         masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
         return masks
 
-    def select_masks(
-        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Determine if we should return the multiclick mask or not from the number of points.
-        # The reweighting is used to avoid control flow.
-        score_reweight = torch.tensor(
-            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
-        ).to(iou_preds.device)
-        score = iou_preds + (num_points - 2.5) * score_reweight
-        best_idx = torch.argmax(score, dim=1)
-        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
-        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
-
-        return masks, iou_preds
 
     @torch.no_grad()
     def forward(
         self,
         image_embeddings: torch.Tensor,
+        interm_embeddings: torch.Tensor,
         point_coords: torch.Tensor,
         point_labels: torch.Tensor,
         mask_input: torch.Tensor,
@@ -117,20 +106,42 @@ def forward(
         sparse_embedding = self._embed_points(point_coords, point_labels)
         dense_embedding = self._embed_masks(mask_input, has_mask_input)
 
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2)  # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
+
         masks, scores = self.model.mask_decoder.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=self.model.prompt_encoder.get_dense_pe(),
             sparse_prompt_embeddings=sparse_embedding,
             dense_prompt_embeddings=dense_embedding,
+            hq_features=hq_features,
         )
 
         if self.use_stability_score:
             scores = calculate_stability_score(
                 masks, self.model.mask_threshold, self.stability_score_offset
             )
 
-        if self.return_single_mask:
-            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+        if self.multimask_output:
+            # mask with the highest score
+            mask_slice = slice(1, self.model.mask_decoder.num_mask_tokens - 1)
+            scores = scores[:, mask_slice]
+            scores, max_iou_idx = torch.max(scores, dim=1)
+            scores = scores.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
+        else:
+            # single mask output, default
+            mask_slice = slice(0, 1)
+            scores = scores[:, mask_slice]
+            masks_sam = masks[:, mask_slice]
+
+        masks_hq = masks[:, slice(self.model.mask_decoder.num_mask_tokens - 1, self.model.mask_decoder.num_mask_tokens)]
+
+        if self.hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq
 
         upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
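
The slicing in the new `forward` relies on the layout of the decoder's mask tokens: SAM's single-mask token at index 0, its multi-mask tokens from index 1 to `num_mask_tokens - 2`, and the HQ token appended last. Below is a small sketch of the selection with dummy tensors, assuming `num_mask_tokens == 5` (4 SAM tokens plus one HQ token); the shapes are placeholders for the low-resolution mask logits.

```python
import torch

num_mask_tokens = 5                                   # assumption: 4 SAM mask tokens + 1 HQ token
masks = torch.randn(1, num_mask_tokens, 256, 256)     # low-res mask logits, one map per token
scores = torch.rand(1, num_mask_tokens)               # predicted mask quality per token

# multimask_output=True: keep only SAM's multi-mask tokens (indices 1..3) and pick the best one.
mask_slice = slice(1, num_mask_tokens - 1)
best_scores, best_idx = torch.max(scores[:, mask_slice], dim=1)
masks_sam = masks[:, mask_slice][torch.arange(masks.size(0)), best_idx].unsqueeze(1)

# The HQ token is always the last one.
masks_hq = masks[:, num_mask_tokens - 1 : num_mask_tokens]

# hq_token_only=False (default): the HQ output corrects the selected SAM output.
combined = masks_sam + masks_hq
print(combined.shape)  # torch.Size([1, 1, 256, 256])
```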
