Skip to content

Commit 28ad471

Browse files
committed
added florence2+sam2 for images
1 parent cfa0ecd commit 28ad471

File tree

5 files changed

+117
-10
lines changed

5 files changed

+117
-10
lines changed

tests/integ/test_tools.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
florence2_object_detection,
1313
florence2_roberta_vqa,
1414
florence2_ocr,
15+
florence2_sam2_image,
1516
generate_pose_image,
1617
generate_soft_edge_image,
1718
git_vqa_v2,
@@ -88,6 +89,17 @@ def test_grounding_sam():
8889
assert len([res["mask"] for res in result]) == 24
8990

9091

92+
def test_florence2_sam2_image():
93+
img = ski.data.coins()
94+
result = florence2_sam2_image(
95+
prompt="coin",
96+
image=img,
97+
)
98+
assert len(result) == 25
99+
assert [res["label"] for res in result] == ["coin"] * 25
100+
assert len([res["mask"] for res in result]) == 25
101+
102+
91103
def test_segmentation():
92104
img = ski.data.coins()
93105
result = detr_segmentation(

vision_agent/tools/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
extract_frames,
1919
florence2_image_caption,
2020
florence2_object_detection,
21-
florence2_roberta_vqa,
2221
florence2_ocr,
22+
florence2_roberta_vqa,
23+
florence2_sam2_image,
2324
generate_pose_image,
2425
generate_soft_edge_image,
2526
get_tool_documentation,

vision_agent/tools/tool_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import logging
33
import os
4-
from typing import Any, Callable, Dict, List, MutableMapping, Optional
4+
from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple
55

66
import pandas as pd
77
from IPython.display import display
@@ -28,7 +28,10 @@ class ToolCallTrace(BaseModel):
2828

2929

3030
def send_inference_request(
31-
payload: Dict[str, Any], endpoint_name: str, v2: bool = False
31+
payload: Dict[str, Any],
32+
endpoint_name: str,
33+
files: Optional[List[Tuple[Any, ...]]] = None,
34+
v2: bool = False,
3235
) -> Dict[str, Any]:
3336
try:
3437
if runtime_tag := os.environ.get("RUNTIME_TAG", ""):
@@ -44,7 +47,7 @@ def send_inference_request(
4447
response={},
4548
error=None,
4649
)
47-
headers = {"Content-Type": "application/json", "apikey": _LND_API_KEY}
50+
headers = {"apikey": _LND_API_KEY}
4851
if "TOOL_ENDPOINT_AUTH" in os.environ:
4952
headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"]
5053
headers.pop("apikey")
@@ -54,7 +57,11 @@ def send_inference_request(
5457
num_retry=3,
5558
headers=headers,
5659
)
57-
res = session.post(url, json=payload)
60+
61+
if files is not None:
62+
res = session.post(url, data=payload, files=files)
63+
else:
64+
res = session.post(url, json=payload)
5865
if res.status_code != 200:
5966
tool_call_trace.error = Error(
6067
name="RemoteToolCallFailed",

vision_agent/tools/tools.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,36 @@
22
import json
33
import logging
44
import tempfile
5-
from pathlib import Path
65
from importlib import resources
6+
from pathlib import Path
77
from typing import Any, Dict, List, Optional, Tuple, Union, cast
88

99
import cv2
10-
import requests
1110
import numpy as np
12-
from pytube import YouTube # type: ignore
11+
import requests
1312
from moviepy.editor import ImageSequenceClip
1413
from PIL import Image, ImageDraw, ImageFont
1514
from pillow_heif import register_heif_opener # type: ignore
15+
from pytube import YouTube # type: ignore
1616

1717
from vision_agent.tools.tool_utils import (
18-
send_inference_request,
1918
get_tool_descriptions,
2019
get_tool_documentation,
2120
get_tools_df,
21+
send_inference_request,
2222
)
2323
from vision_agent.utils import extract_frames_from_video
2424
from vision_agent.utils.execute import FileSerializer, MimeType
2525
from vision_agent.utils.image_utils import (
2626
b64_to_pil,
27+
convert_quad_box_to_bbox,
2728
convert_to_b64,
2829
denormalize_bbox,
2930
get_image_size,
3031
normalize_bbox,
31-
convert_quad_box_to_bbox,
32+
numpy_to_bytes,
3233
rle_decode,
34+
rle_decode_array,
3335
)
3436

3537
register_heif_opener()
@@ -242,6 +244,59 @@ def grounding_sam(
242244
return return_data
243245

244246

247+
def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
248+
"""'florence2_sam2_image' is a tool that can segment multiple objects given a
249+
text prompt such as category names or referring expressions. The categories in text
250+
prompt are separated by commas. It returns a list of bounding boxes, label names,
251+
mask file names and associated probability scores.
252+
253+
Parameters:
254+
prompt (str): The prompt to ground to the image.
255+
image (np.ndarray): The image to ground the prompt to.
256+
257+
Returns:
258+
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
259+
bounding box, and mask of the detected objects with normalized coordinates
260+
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
261+
and xmax and ymax are the coordinates of the bottom-right of the bounding box.
262+
The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
263+
the background.
264+
265+
Example
266+
-------
267+
>>> florence2_sam2_image("car, dinosaur", image)
268+
[
269+
{
270+
'score': 0.99,
271+
'label': 'dinosaur',
272+
'bbox': [0.1, 0.11, 0.35, 0.4],
273+
'mask': array([[0, 0, 0, ..., 0, 0, 0],
274+
[0, 0, 0, ..., 0, 0, 0],
275+
...,
276+
[0, 0, 0, ..., 0, 0, 0],
277+
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
278+
},
279+
]
280+
"""
281+
buffer_bytes = numpy_to_bytes(image)
282+
283+
files = [("image", buffer_bytes)]
284+
payload = {
285+
"prompts": prompt.split(","),
286+
"function_name": "florence2_sam2_image",
287+
}
288+
data: Dict[str, Any] = send_inference_request(
289+
payload, "florence2-sam2", files=files, v2=True
290+
)
291+
return_data = []
292+
for _, data_i in data["0"].items():
293+
mask = rle_decode_array(data_i["mask"])
294+
label = data_i["label"]
295+
bbox = normalize_bbox(data_i["bounding_box"], data_i["mask"]["size"])
296+
return_data.append({"label": label, "bbox": bbox, "mask": mask, "score": 1.0})
297+
return return_data
298+
299+
245300
def extract_frames(
246301
video_uri: Union[str, Path], fps: float = 0.5
247302
) -> List[Tuple[np.ndarray, float]]:

vision_agent/utils/image_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Utility functions for image processing."""
22

33
import base64
4+
import io
45
from importlib import resources
56
from io import BytesIO
67
from pathlib import Path
@@ -63,6 +64,28 @@ def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
6364
return img.reshape(shape)
6465

6566

67+
def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray:
68+
r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
69+
70+
Parameters:
71+
mask: The mask in run-length encoded as an array.
72+
"""
73+
size = rle["size"]
74+
counts = rle["counts"]
75+
76+
total_elements = size[0] * size[1]
77+
flattened_mask = np.zeros(total_elements, dtype=np.uint8)
78+
79+
current_pos = 0
80+
for i, count in enumerate(counts):
81+
if i % 2 == 1:
82+
flattened_mask[current_pos : current_pos + count] = 1
83+
current_pos += count
84+
85+
binary_mask = flattened_mask.reshape(size, order="F")
86+
return binary_mask
87+
88+
6689
def b64_to_pil(b64_str: str) -> ImageType:
6790
r"""Convert a base64 string to a PIL Image.
6891
@@ -78,6 +101,15 @@ def b64_to_pil(b64_str: str) -> ImageType:
78101
return Image.open(BytesIO(base64.b64decode(b64_str)))
79102

80103

104+
def numpy_to_bytes(image: np.ndarray) -> bytes:
105+
pil_image = Image.fromarray(image).convert("RGB")
106+
image_buffer = io.BytesIO()
107+
pil_image.save(image_buffer, format="PNG")
108+
buffer_bytes = image_buffer.getvalue()
109+
image_buffer.close()
110+
return buffer_bytes
111+
112+
81113
def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
82114
r"""Get the size of an image.
83115

0 commit comments

Comments
 (0)