diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 0308361d..cbc5eeb8 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -190,9 +190,8 @@ def test_loca_visual_prompt_counting() -> None: def test_countgd_counting() -> None: img = ski.data.coins() - result = countgd_counting(image=img, prompt="coin") - assert result["count"] == 24 + assert len(result) == 24 def test_countgd_example_based_counting() -> None: @@ -201,7 +200,7 @@ def test_countgd_example_based_counting() -> None: visual_prompts=[[85, 106, 122, 145]], image=img, ) - assert result["count"] == 24 + assert len(result) == 24 def test_git_vqa_v2() -> None: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 08bf0370..1ad6ea11 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -540,10 +540,11 @@ def countgd_counting( "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_counting"} - resp: List[Dict[str, Any]] = send_inference_request( + resp_data: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload ) # type: ignore - return resp["data"] + + return resp_data def countgd_example_based_counting( @@ -583,16 +584,20 @@ def countgd_example_based_counting( ] """ image_b64 = convert_to_b64(image) + visual_prompts = [ + denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts + ] payload = { "image": image_b64, "visual_prompts": visual_prompts, "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_example_based_counting"} - resp: List[Dict[str, Any]] = send_inference_request( + resp_data: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload ) # type: ignore - return resp["data"] + + return resp_data def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: