diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index eae02205..0308361d 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -198,7 +198,7 @@ def test_countgd_counting() -> None: def test_countgd_example_based_counting() -> None: img = ski.data.coins() result = countgd_example_based_counting( - visual_prompt=[[85, 106, 122, 145]], + visual_prompts=[[85, 106, 122, 145]], image=img, ) assert result["count"] == 24 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 3478e0a8..cfc43534 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -541,7 +541,7 @@ def countgd_counting( "box_threshold": box_threshold, "function_name": "countgd_counting", } - data: Dict[str, Any] = send_inference_request( + data: List[Dict[str, Any]] = send_inference_request( payload, "countgd_counting", files=files, v2=True ) return data @@ -588,7 +588,7 @@ def countgd_example_based_counting( "box_threshold": box_threshold, "function_name": "countgd_example_based_counting", } - data: Dict[str, Any] = send_inference_request( + data: List[Dict[str, Any]] = send_inference_request( payload, "countgd_example_based_counting", files=files, v2=True ) return data