From 6f0dd099362c14baeee523441ae330c7c41a58fd Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 9 Aug 2024 20:47:47 -0700 Subject: [PATCH] fixed integ tests --- tests/integ/test_tools.py | 2 +- vision_agent/tools/tools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 57d536bd..c1a779fd 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -118,7 +118,7 @@ def test_nsfw_classification(): result = vit_nsfw_classification( image=img, ) - assert result["labels"] == "normal" + assert result["label"] == "normal" def test_image_caption() -> None: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index d299a065..4e0fdf27 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -537,7 +537,7 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]: Example ------- >>> vit_nsfw_classification(image) - {"labels": "normal", "scores": 0.68}, + {"label": "normal", "scores": 0.68}, """ image_b64 = convert_to_b64(image) @@ -603,7 +603,7 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> } answer = send_inference_request(data, "florence2", v2=True) - return answer["text"][0] # type: ignore + return answer[task] # type: ignore def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]: