diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 57d536bd..c1a779fd 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -118,7 +118,7 @@ def test_nsfw_classification(): result = vit_nsfw_classification( image=img, ) - assert result["labels"] == "normal" + assert result["label"] == "normal" def test_image_caption() -> None: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index d299a065..4e0fdf27 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -537,7 +537,7 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]: Example ------- >>> vit_nsfw_classification(image) - {"labels": "normal", "scores": 0.68}, + {"label": "normal", "scores": 0.68}, """ image_b64 = convert_to_b64(image) @@ -603,7 +603,7 @@ def florencev2_image_caption(image: np.ndarray, detail_caption: bool = True) -> } answer = send_inference_request(data, "florence2", v2=True) - return answer["text"][0] # type: ignore + return answer[task] # type: ignore def florencev2_object_detection(image: np.ndarray) -> List[Dict[str, Any]]: