diff --git a/tests/test_tools.py b/tests/test_tools.py index b6cb43d9..2a848a02 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -18,9 +18,8 @@ def test_grounding_dino(): prompt="coin", image=img, ) - assert result["labels"] == ["coin"] * 24 - assert len(result["bboxes"]) == 24 - assert len(result["scores"]) == 24 + assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 def test_grounding_sam(): @@ -29,10 +28,9 @@ def test_grounding_sam(): prompt="coin", image=img, ) - assert result["labels"] == ["coin"] * 24 - assert len(result["bboxes"]) == 24 - assert len(result["scores"]) == 24 - assert len(result["masks"]) == 24 + assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 + assert len([res["mask"] for res in result]) == 24 def test_clip(): diff --git a/vision_agent/tools/tools_v2.py b/vision_agent/tools/tools_v2.py index 2b0d7784..37e76a28 100644 --- a/vision_agent/tools/tools_v2.py +++ b/vision_agent/tools/tools_v2.py @@ -68,12 +68,13 @@ def grounding_dino( box_threshold (float, optional): The threshold for the box detection. Defaults to 0.20. iou_threshold (float, optional): The threshold for the Intersection over Union - (IoU). Defaults to 0.75. + (IoU). Defaults to 0.20. Returns: List[Dict[str, Any]]: A list of dictionaries containing the score, label, and bounding box of the detected objects with normalized coordinates - (x1, y1, x2, y2). + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and + xmax and ymax are the coordinates of the bottom-right of the bounding box. Example ------- @@ -120,12 +121,15 @@ def grounding_sam( box_threshold (float, optional): The threshold for the box detection. Defaults to 0.20. iou_threshold (float, optional): The threshold for the Intersection over Union - (IoU). Defaults to 0.75. + (IoU). Defaults to 0.20. 
Returns: List[Dict[str, Any]]: A list of dictionaries containing the score, label, bounding box, and mask of the detected objects with normalized coordinates - (x1, y1, x2, y2). + (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and + xmax and ymax are the coordinates of the bottom-right of the bounding box. + The mask is a binary 2D numpy array where 1 indicates the object and 0 indicates + the background. Example -------