
Commit 18c3fbc

Add text to segmentation demo code (#451)
Co-authored-by: yizhuoz004 <[email protected]>
1 parent ae56948 commit 18c3fbc


2 files changed: +186 −4 lines


tripy/examples/segment-anything-model-v2/sam2/sam2_image_predictor.py

Lines changed: 27 additions & 4 deletions
@@ -120,6 +120,7 @@ def predict_batch(
         self,
         point_coords_batch: List[np.ndarray] = None,
         point_labels_batch: List[np.ndarray] = None,
+        box_batch: List[np.ndarray] = None,
         multimask_output: bool = True,
         return_logits: bool = False,
         normalize_coords=True,
@@ -164,17 +165,19 @@ def concat_batch(x):
 
         point_coords = concat_batch(point_coords_batch)
         point_labels = concat_batch(point_labels_batch)
+        box = concat_batch(box_batch)
 
-        _, unnorm_coords, labels, _ = self._prep_prompts(
+        _, unnorm_coords, labels, unnorm_box = self._prep_prompts(
             point_coords,
             point_labels,
-            None,  # box
+            box,  # box
             None,  # mask_input
             normalize_coords,
         )
         masks, iou_predictions, low_res_masks = self._predict(
             unnorm_coords,
             labels,
+            unnorm_box,
             multimask_output,
             return_logits=return_logits,
         )
@@ -220,6 +223,7 @@ def _predict(
         self,
         point_coords: torch.Tensor,
         point_labels: torch.Tensor,
+        boxes: Optional[torch.Tensor] = None,
         multimask_output: bool = True,
         return_logits: bool = False,
         img_idx: int = -1,
@@ -256,9 +260,28 @@ def _predict(
         if not self._is_image_set:
             raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
 
+        if point_coords is not None:
+            concat_points = (point_coords, point_labels)
+        else:
+            concat_points = None
+
+        # Embed prompts
+        if boxes is not None:
+            box_coords = boxes.reshape(-1, 2, 2)
+            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+            box_labels = box_labels.repeat(boxes.size(0), 1)
+            # we merge "boxes" and "points" into a single "concat_points" input (where
+            # boxes are added at the beginning) to sam_prompt_encoder
+            if concat_points is not None:
+                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                concat_points = (concat_coords, concat_labels)
+            else:
+                concat_points = (box_coords, box_labels)
+
         sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
-            points_x=tp.Tensor(point_coords.contiguous()),
-            points_y=tp.Tensor(point_labels.contiguous()),
+            points_x=tp.Tensor(concat_points[0].contiguous()),
+            points_y=tp.Tensor(concat_points[1].contiguous()),
         )
 
         # Predict masks
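
The new boxes argument follows SAM's convention of encoding a box prompt as its two corner points, labeled 2 (top-left) and 3 (bottom-right), and prepending them to any regular point prompts before they reach sam_prompt_encoder. The standalone sketch below walks through that transformation with made-up coordinates; it is illustrative only and not part of the commit.

import torch

# One (x1, y1, x2, y2) box and one positive click; values are placeholders.
boxes = torch.tensor([[100.0, 150.0, 400.0, 500.0]])
point_coords = torch.tensor([[[250.0, 300.0]]])
point_labels = torch.tensor([[1]], dtype=torch.int)

# A box becomes two corner points with the special labels 2 and 3.
box_coords = boxes.reshape(-1, 2, 2)                                   # (N, 2, 2)
box_labels = torch.tensor([[2, 3]], dtype=torch.int).repeat(boxes.size(0), 1)

# Boxes go first, then the user's points, matching the merge logic in _predict above.
concat_coords = torch.cat([box_coords, point_coords], dim=1)           # (N, 3, 2)
concat_labels = torch.cat([box_labels, point_labels], dim=1)           # (N, 3)
print(concat_coords.shape, concat_labels.shape)                        # torch.Size([1, 3, 2]) torch.Size([1, 3])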
Lines changed: 159 additions & 0 deletions
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import os
+import cv2
+import torch
+from typing import Optional
+import numpy as np
+import supervision as sv
+from PIL import Image
+from sam2.build_sam import build_sam2_video_predictor, build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+
+def main(video_dir: str, text: str, save_path: Optional[str] = None):
+    """
+    Main execution function.
+
+    Args:
+        video_dir (str): Directory where the video frames are stored
+        text (str): Grounding DINO text prompt describing the objects to segment
+        save_path (str, optional): Directory to save visualizations
+    """
+
+    sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
+    model_cfg = "sam2_hiera_l.yaml"
+
+    video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=torch.device("cuda"))
+
+    sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
+    image_predictor = SAM2ImagePredictor(sam2_image_model)
+
+    model_id = "IDEA-Research/grounding-dino-tiny"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    processor = AutoProcessor.from_pretrained(model_id)
+    grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
+
+    # scan all the JPEG frame names in this directory
+    frame_names = [p for p in os.listdir(video_dir) if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]]
+    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+
+    # init video predictor state
+    inference_state = video_predictor.init_state(video_path=video_dir)
+
+    ann_frame_idx = 0  # the frame index we interact with
+    ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)
+
+    """
+    Prompt Grounding DINO and SAM image predictor to get the box and mask
+    """
+
+    # prompt Grounding DINO to get the box coordinates on a specific frame
+    img_path = os.path.join(video_dir, frame_names[ann_frame_idx])
+    image = Image.open(img_path)
+
+    # run Grounding DINO on the image
+    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = grounding_model(**inputs)
+
+    results = processor.post_process_grounded_object_detection(
+        outputs, inputs.input_ids, box_threshold=0.25, text_threshold=0.3, target_sizes=[image.size[::-1]]
+    )
+
+    # prompt SAM image predictor to get the mask for the object
+    image_predictor.set_image_batch([np.array(image.convert("RGB"))])
+
+    # process the detection results
+    input_boxes = results[0]["boxes"]
+    OBJECTS = results[0]["labels"]
+
+    # prompt SAM 2 image predictor to get the mask for the object
+    masks, scores, logits = image_predictor._predict(
+        point_coords=None,
+        point_labels=None,
+        boxes=input_boxes,
+        multimask_output=True,
+    )
+
+    # convert the mask shape to (n, H, W)
+    if masks.ndim == 3:
+        masks = masks[None]
+        scores = scores[None]
+        logits = logits[None]
+    elif masks.ndim == 4:
+        masks = masks[:, 0, :, :]
+
+    """
+    Register each object's positive points to video predictor
+    """
+    input_boxes = input_boxes.cpu().numpy()
+    for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
+        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
+            inference_state=inference_state,
+            frame_idx=ann_frame_idx,
+            obj_id=object_id,
+            box=box,
+        )
+
+    """
+    Propagate the video predictor to get the segmentation results for each frame
+    """
+    torch.cuda.empty_cache()
+    gc.collect()
+
+    video_segments = {}  # video_segments contains the per-frame segmentation results
+    for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
+        video_segments[out_frame_idx] = {
+            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
+        }
+
+    """
+    Visualize the segmentation results across the video and save them
+    """
+
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+
+    ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}
+    for frame_idx, segments in video_segments.items():
+        img = cv2.imread(os.path.join(video_dir, frame_names[frame_idx]))
+
+        object_ids = list(segments.keys())
+        masks = list(segments.values())
+        masks = np.concatenate(masks, axis=0)
+
+        detections = sv.Detections(
+            xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
+            mask=masks,  # (n, h, w)
+            class_id=np.array(object_ids, dtype=np.int32),
+        )
+        box_annotator = sv.BoxAnnotator()
+        annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
+        label_annotator = sv.LabelAnnotator()
+        annotated_frame = label_annotator.annotate(
+            annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids]
+        )
+        mask_annotator = sv.MaskAnnotator()
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+        cv2.imwrite(os.path.join(save_path, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
+
+
+if __name__ == "__main__":
+    main("./bedroom", "boy.girl.", save_path="output")
