Skip to content

Commit

Permalink
fixing test cases for grounding tools as the output format changed
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed May 17, 2024
1 parent d9dd6e7 commit 65abbd3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
12 changes: 5 additions & 7 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_grounding_dino():
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_grounding_sam():
Expand All @@ -29,10 +28,9 @@ def test_grounding_sam():
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result["masks"]) == 24
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert len([res["mask"] for res in result]) == 24


def test_clip():
Expand Down
12 changes: 8 additions & 4 deletions vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ def grounding_dino(
box_threshold (float, optional): The threshold for the box detection. Defaults
to 0.20.
iou_threshold (float, optional): The threshold for the Intersection over Union
(IoU). Defaults to 0.75.
(IoU). Defaults to 0.20.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
bounding box of the detected objects with normalized coordinates
(x1, y1, x2, y2).
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
xmax and ymax are the coordinates of the bottom-right of the bounding box.
Example
-------
Expand Down Expand Up @@ -120,12 +121,15 @@ def grounding_sam(
box_threshold (float, optional): The threshold for the box detection. Defaults
to 0.20.
iou_threshold (float, optional): The threshold for the Intersection over Union
(IoU). Defaults to 0.75.
(IoU). Defaults to 0.20.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
bounding box, and mask of the detected objects with normalized coordinates
(x1, y1, x2, y2).
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
xmax and ymax are the coordinates of the bottom-right of the bounding box.
The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
the background.
Example
-------
Expand Down

0 comments on commit 65abbd3

Please sign in to comment.