Commit e24762e

added florence2+sam2 for video

1 parent 9ce9ec8 commit e24762e
File tree

3 files changed: +158 -28 lines changed


vision_agent/tools/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
     florence2_ocr,
     florence2_roberta_vqa,
     florence2_sam2_image,
+    florence2_sam2_video,
     generate_pose_image,
     generate_soft_edge_image,
     get_tool_documentation,
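
With that re-export in place, the video tool sits next to the image variant in the package namespace. A minimal sketch of the new import surface:

# Illustrative: the video tool is now importable next to the image variant.
from vision_agent.tools import florence2_sam2_image, florence2_sam2_video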

vision_agent/tools/tools.py
Lines changed: 137 additions & 28 deletions
@@ -27,6 +27,7 @@
     convert_quad_box_to_bbox,
     convert_to_b64,
     denormalize_bbox,
+    frames_to_bytes,
     get_image_size,
     normalize_bbox,
     numpy_to_bytes,
@@ -184,10 +185,10 @@ def grounding_sam(
     box_threshold: float = 0.20,
     iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
-    """'grounding_sam' is a tool that can segment multiple objects given a
-    text prompt such as category names or referring expressions. The categories in text
-    prompt are separated by commas or periods. It returns a list of bounding boxes,
-    label names, mask file names and associated probability scores.
+    """'grounding_sam' is a tool that can segment multiple objects given a text prompt
+    such as category names or referring expressions. The categories in text prompt are
+    separated by commas or periods. It returns a list of bounding boxes, label names,
+    mask file names and associated probability scores.

     Parameters:
         prompt (str): The prompt to ground to the image.
@@ -245,8 +246,8 @@ def grounding_sam(


 def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
-    """'florence2_sam2_image' is a tool that can segment multiple objects given a
-    text prompt such as category names or referring expressions. The categories in text
+    """'florence2_sam2_image' is a tool that can segment multiple objects given a text
+    prompt such as category names or referring expressions. The categories in the text
     prompt are separated by commas. It returns a list of bounding boxes, label names,
     mask file names and associated probability scores.
@@ -297,6 +298,63 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
     return return_data


+def florence2_sam2_video(
+    prompt: str, frames: List[np.ndarray]
+) -> List[List[Dict[str, Any]]]:
+    """'florence2_sam2_video' is a tool that can segment and track multiple objects
+    in a video given a text prompt such as category names or referring expressions. The
+    categories in the text prompt are separated by commas. It returns tracked objects
+    as masks, labels, and scores for each frame.
+
+    Parameters:
+        prompt (str): The prompt to ground to the video.
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+
+    Returns:
+        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the
+            label, score and mask of the detected objects. The outer list represents
+            each frame and the inner list is the objects per frame. The label contains
+            the object ID followed by the label name. The objects are only identified
+            in the first frame and tracked throughout the video.
+
+    Example
+    -------
+        >>> florence2_sam2_video("car, dinosaur", frames)
+        [
+            [
+                {
+                    'label': '0: dinosaur',
+                    'score': 1.0,
+                    'mask': array([[0, 0, 0, ..., 0, 0, 0],
+                        [0, 0, 0, ..., 0, 0, 0],
+                        ...,
+                        [0, 0, 0, ..., 0, 0, 0],
+                        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+                },
+            ],
+        ]
+    """
+
+    buffer_bytes = frames_to_bytes(frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        "prompts": prompt.split(","),
+        "function_name": "florence2_sam2_video",
+    }
+    data: Dict[str, Any] = send_inference_request(
+        payload, "florence2-sam2", files=files, v2=True
+    )
+    return_data = []
+    for frame_i in data.keys():
+        return_frame_data = []
+        for obj_id, data_j in data[frame_i].items():
+            mask = rle_decode_array(data_j["mask"])
+            label = obj_id + ": " + data_j["label"]
+            return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
+        return_data.append(return_frame_data)
+    return return_data
+
+
 def extract_frames(
     video_uri: Union[str, Path], fps: float = 0.5
 ) -> List[Tuple[np.ndarray, float]]:
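
A sketch of how the new tool might be driven end to end, using `extract_frames` (shown in context above) to sample a local video; the video path and fps value are placeholders:

# Illustrative pipeline: sample frames from a video, then segment and track.
# "video.mp4" and fps=1.0 are placeholder values for this sketch.
from vision_agent.tools.tools import extract_frames, florence2_sam2_video

frames_and_timestamps = extract_frames("video.mp4", fps=1.0)
frames = [frame for frame, _ in frames_and_timestamps]

detections = florence2_sam2_video("car, dinosaur", frames)
for frame_idx, frame_dets in enumerate(detections):
    for det in frame_dets:
        # each mask is a uint8 array with the same height/width as its frame
        print(frame_idx, det["label"], det["score"], det["mask"].shape)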
@@ -1274,15 +1332,43 @@ def overlay_bounding_boxes(
     return np.array(pil_image)


+def _get_text_coords_from_mask(
+    mask: np.ndarray, v_gap: int = 10, h_gap: int = 10
+) -> Tuple[int, int]:
+    mask = mask.astype(np.uint8)
+    if np.sum(mask) == 0:
+        return (0, 0)
+
+    rows, cols = np.nonzero(mask)
+    top = rows.min()
+    bottom = rows.max()
+    left = cols.min()
+    right = cols.max()
+
+    if top - v_gap < 0:
+        if bottom + v_gap > mask.shape[0]:
+            top = top
+        else:
+            top = bottom + v_gap
+    else:
+        top = top - v_gap
+
+    return left + (right - left) // 2 - h_gap, top
+
+
 def overlay_segmentation_masks(
-    image: np.ndarray, masks: List[Dict[str, Any]]
-) -> np.ndarray:
+    medias: Union[np.ndarray, List[np.ndarray]],
+    masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
+    draw_label: bool = True,
+) -> Union[np.ndarray, List[np.ndarray]]:
     """'overlay_segmentation_masks' is a utility function that displays segmentation
     masks.

     Parameters:
-        image (np.ndarray): The image to display the masks on.
-        masks (List[Dict[str, Any]]): A list of dictionaries containing the masks.
+        medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
+            the masks on.
+        masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
+            dictionaries containing the masks.

     Returns:
         np.ndarray: The image with the masks displayed.
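
The new `_get_text_coords_from_mask` helper picks a label anchor point: horizontally centered on the mask, and `v_gap` pixels above it when there is room, otherwise below it. A worked toy example (mask size and gap values chosen arbitrarily for illustration):

# Toy check of the label-placement helper (a private function in tools.py);
# the mask size and gap values are arbitrary.
import numpy as np

from vision_agent.tools.tools import _get_text_coords_from_mask

mask = np.zeros((100, 100), dtype=np.uint8)
mask[40:60, 30:70] = 1  # blob with top=40, bottom=59, left=30, right=69

# top - v_gap = 40 - 10 = 30 >= 0, so the label is anchored above the mask:
#   x = left + (right - left) // 2 - h_gap = 30 + 19 - 10 = 39
#   y = top - v_gap = 30
x, y = _get_text_coords_from_mask(mask, v_gap=10, h_gap=10)
print(x, y)  # 39 30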
@@ -1302,27 +1388,50 @@ def overlay_segmentation_masks(
         }],
     )
     """
-    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
+    medias_int: List[np.ndarray] = (
+        [medias] if isinstance(medias, np.ndarray) else medias
+    )
+    masks_int = [masks] if isinstance(masks[0], dict) else masks
+    masks_int = cast(List[List[Dict[str, Any]]], masks_int)

-    if len(set([mask["label"] for mask in masks])) > len(COLORS):
-        _LOGGER.warning(
-            "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
-        )
+    labels = set()
+    for mask_i in masks_int:
+        for mask_j in mask_i:
+            labels.add(mask_j["label"])
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

-    color = {
-        label: COLORS[i % len(COLORS)]
-        for i, label in enumerate(set([mask["label"] for mask in masks]))
-    }
-    masks = sorted(masks, key=lambda x: x["label"], reverse=True)
+    width, height = Image.fromarray(medias_int[0]).size
+    fontsize = max(12, int(min(width, height) / 40))
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )

-    for elt in masks:
-        mask = elt["mask"]
-        label = elt["label"]
-        np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
-        np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
-        mask_img = Image.fromarray(np_mask.astype(np.uint8))
-        pil_image = Image.alpha_composite(pil_image, mask_img)
-    return np.array(pil_image)
+    frame_out = []
+    for i, frame in enumerate(medias_int):
+        pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
+        for elt in masks_int[i]:
+            mask = elt["mask"]
+            label = elt["label"]
+            np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
+            np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
+            mask_img = Image.fromarray(np_mask.astype(np.uint8))
+            pil_image = Image.alpha_composite(pil_image, mask_img)
+
+            if draw_label:
+                draw = ImageDraw.Draw(pil_image)
+                text_box = draw.textbbox((0, 0), text=label, font=font)
+                x, y = _get_text_coords_from_mask(
+                    mask,
+                    v_gap=(text_box[3] - text_box[1]) + 10,
+                    h_gap=(text_box[2] - text_box[0]) // 2,
+                )
+                if x != 0 and y != 0:
+                    text_box = draw.textbbox((x, y), text=label, font=font)
+                    draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
+                    draw.text((x, y), label, fill="black", font=font)
+        frame_out.append(np.array(pil_image))  # type: ignore
+    return frame_out[0] if len(frame_out) == 1 else frame_out


 def overlay_heat_map(
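
Because `overlay_segmentation_masks` now accepts either a single image or a list of frames, the output of `florence2_sam2_video` can be passed through unchanged; a sketch, reusing `frames` and `detections` from the pipeline example above:

# Sketch: rendering the tracked masks; `frames` and `detections` are assumed
# to come from the florence2_sam2_video example above.
from vision_agent.tools.tools import overlay_segmentation_masks

# A single image plus one list of masks behaves as before (one array back):
first_rendered = overlay_segmentation_masks(frames[0], detections[0])

# A list of frames plus per-frame masks returns a list of rendered frames:
rendered_frames = overlay_segmentation_masks(frames, detections, draw_label=True)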

vision_agent/utils/image_utils.py
Lines changed: 20 additions & 0 deletions
@@ -2,12 +2,14 @@

 import base64
 import io
+import tempfile
 from importlib import resources
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, List, Tuple, Union

 import numpy as np
+from moviepy.editor import ImageSequenceClip
 from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as ImageType
@@ -86,6 +88,24 @@ def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray:
     return binary_mask


+def frames_to_bytes(
+    frames: List[np.ndarray], fps: float = 10, file_ext: str = "mp4"
+) -> bytes:
+    r"""Convert a list of frames to a video file encoded into a byte string.
+
+    Parameters:
+        frames: the list of frames
+        fps: the frames per second of the video
+        file_ext: the file extension of the video file
+    """
+    with tempfile.NamedTemporaryFile(delete=True) as temp_file:
+        clip = ImageSequenceClip(frames, fps=fps)
+        clip.write_videofile(temp_file.name + f".{file_ext}", fps=fps)
+        with open(temp_file.name + f".{file_ext}", "rb") as f:
+            buffer_bytes = f.read()
+    return buffer_bytes
+
+
 def b64_to_pil(b64_str: str) -> ImageType:
     r"""Convert a base64 string to a PIL Image.
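
A sketch of the encoding step that `florence2_sam2_video` relies on: `frames_to_bytes` turns in-memory frames into an mp4 byte string suitable for a multipart upload (frame shape, count, and fps below are illustrative):

# Sketch: encode dummy RGB frames into mp4 bytes; shape, count and fps are
# arbitrary values for illustration.
import numpy as np

from vision_agent.utils.image_utils import frames_to_bytes

frames = [np.zeros((240, 320, 3), dtype=np.uint8) for _ in range(30)]
video_bytes = frames_to_bytes(frames, fps=10, file_ext="mp4")
files = [("video", video_bytes)]  # the multipart shape the inference call expects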
