Skip to content

Commit 65abbd3

Browse files
fixing test cases for grounding tools as the output format changed
1 parent d9dd6e7 commit 65abbd3

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

tests/test_tools.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ def test_grounding_dino():
1818
prompt="coin",
1919
image=img,
2020
)
21-
assert result["labels"] == ["coin"] * 24
22-
assert len(result["bboxes"]) == 24
23-
assert len(result["scores"]) == 24
21+
assert len(result) == 24
22+
assert [res["label"] for res in result] == ["coin"] * 24
2423

2524

2625
def test_grounding_sam():
@@ -29,10 +28,9 @@ def test_grounding_sam():
2928
prompt="coin",
3029
image=img,
3130
)
32-
assert result["labels"] == ["coin"] * 24
33-
assert len(result["bboxes"]) == 24
34-
assert len(result["scores"]) == 24
35-
assert len(result["masks"]) == 24
31+
assert len(result) == 24
32+
assert [res["label"] for res in result] == ["coin"] * 24
33+
assert len([res["mask"] for res in result]) == 24
3634

3735

3836
def test_clip():

vision_agent/tools/tools_v2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,13 @@ def grounding_dino(
6868
box_threshold (float, optional): The threshold for the box detection. Defaults
6969
to 0.20.
7070
iou_threshold (float, optional): The threshold for the Intersection over Union
71-
(IoU). Defaults to 0.75.
71+
(IoU). Defaults to 0.20.
7272
7373
Returns:
7474
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
7575
bounding box of the detected objects with normalized coordinates
76-
(x1, y1, x2, y2).
76+
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
77+
xmax and ymax are the coordinates of the bottom-right of the bounding box.
7778
7879
Example
7980
-------
@@ -120,12 +121,15 @@ def grounding_sam(
120121
box_threshold (float, optional): The threshold for the box detection. Defaults
121122
to 0.20.
122123
iou_threshold (float, optional): The threshold for the Intersection over Union
123-
(IoU). Defaults to 0.75.
124+
(IoU). Defaults to 0.20.
124125
125126
Returns:
126127
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
127128
bounding box, and mask of the detected objects with normalized coordinates
128-
(x1, y1, x2, y2).
129+
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
130+
xmax and ymax are the coordinates of the bottom-right of the bounding box.
131+
The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
132+
the background.
129133
130134
Example
131135
-------

0 commit comments

Comments
 (0)