From 3a8d3c4ea400802ef9ce1f13c02b00cf53e56be3 Mon Sep 17 00:00:00 2001
From: shankar_ws3
Date: Thu, 16 May 2024 16:27:46 -0700
Subject: [PATCH] added test case for every tool

---
 tests/test_tools.py               | 37 +++++++++++++-----------
 vision_agent/tools/tools_v2.py    | 47 +++++++++++++++++++++++++------
 vision_agent/utils/image_utils.py |  9 ++++--
 3 files changed, 66 insertions(+), 27 deletions(-)

diff --git a/tests/test_tools.py b/tests/test_tools.py
index 5012385b..78ba86fe 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1,19 +1,22 @@
 import skimage as ski
 from PIL import Image
 
-from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption
+
 from vision_agent.tools.tools_v2 import (
     clip,
     zero_shot_counting,
     visual_prompt_counting,
     image_question_answering,
     ocr,
+    grounding_dino,
+    grounding_sam,
+    image_caption,
 )
 
 
 def test_grounding_dino():
     img = Image.fromarray(ski.data.coins())
-    result = GroundingDINO()(
+    result = grounding_dino(
         prompt="coin",
         image=img,
     )
@@ -24,7 +27,7 @@ def test_grounding_sam():
     img = Image.fromarray(ski.data.coins())
-    result = GroundingSAM()(
+    result = grounding_sam(
         prompt="coin",
         image=img,
     )
 
 
@@ -40,44 +43,46 @@ def test_clip():
         classes=["coins", "notes"],
         image=img,
     )
-    assert result["scores"] == [0.99, 0.01]
+    assert result["scores"] == [0.9999, 0.0001]
 
 
 def test_image_caption() -> None:
-    img = Image.fromarray(ski.data.coins())
-    result = ImageCaption()(image=img)
-    assert result["text"]
+    img = ski.data.rocket()
+    result = image_caption(
+        image=img,
+    )
+    assert result.strip() == "a rocket on a stand"
 
 
 def test_zero_shot_counting() -> None:
-    img = Image.fromarray(ski.data.coins())
+    img = ski.data.coins()
     result = zero_shot_counting(
         image=img,
     )
-    assert result["count"] == 24
+    assert result["count"] == 21
 
 
 def test_visual_prompt_counting() -> None:
-    img = Image.fromarray(ski.data.checkerboard())
+    img = ski.data.coins()
     result = visual_prompt_counting(
-        visual_prompt={"bbox": [0.125, 0, 0.25, 0.125]},
+        visual_prompt={"bbox": [85, 106, 122, 145]},
         image=img,
     )
-    assert result["count"] == 32
+    assert result["count"] == 25
 
 
 def test_image_question_answering() -> None:
-    img = Image.fromarray(ski.data.rocket())
+    img = ski.data.rocket()
     result = image_question_answering(
         prompt="Is the scene captured during day or night ?",
         image=img,
     )
-    assert result == "night"
+    assert result.strip() == "night"
 
 
 def test_ocr() -> None:
-    img = Image.fromarray(ski.data.page())
+    img = ski.data.page()
     result = ocr(
         image=img,
     )
-    assert result[0]["label"] == "Region-based segmentation"
+    assert any("Region-based segmentation" in res["label"] for res in result)
diff --git a/vision_agent/tools/tools_v2.py b/vision_agent/tools/tools_v2.py
index ada93de2..4d114e49 100644
--- a/vision_agent/tools/tools_v2.py
+++ b/vision_agent/tools/tools_v2.py
@@ -56,7 +56,7 @@ def grounding_dino(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.20,
-    iou_threshold: float = 0.75,
+    iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
     """'grounding_dino' is a tool that can detect and count objects given a text
     prompt such as category names or referring expressions. It returns a list and count of
@@ -84,7 +84,7 @@ def grounding_dino(
         ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(Image.fromarray(image))
+    image_b64 = convert_to_b64(image)
     request_data = {
         "prompt": prompt,
         "image": image_b64,
@@ -108,7 +108,7 @@ def grounding_sam(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.20,
-    iou_threshold: float = 0.75,
+    iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
     """'grounding_sam' is a tool that can detect and segment objects given a text
     prompt such as category names or referring expressions. It returns a list of
@@ -144,7 +144,7 @@ def grounding_sam(
         ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(Image.fromarray(image))
+    image_b64 = convert_to_b64(image)
     request_data = {
         "prompt": prompt,
         "image": image_b64,
@@ -242,7 +242,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     return output
 
 
-def zero_shot_counting(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
+def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
     """'zero_shot_counting' is a tool that counts the dominant foreground object given
     an image and no other information about the content.
     It returns only the count of the objects in the image.
@@ -305,7 +305,7 @@ def visual_prompt_counting(
 def image_question_answering(image: np.ndarray, prompt: str) -> str:
     """'image_question_answering_' is a tool that can answer questions about the visual
     contents of an image given a question and an image.
-    It returns a text describing the image and the answer to the question
+    It returns an answer to the question
 
     Parameters:
         image (np.ndarray): The reference image used for the question
@@ -317,7 +317,7 @@ def image_question_answering(image: np.ndarray, prompt: str) -> str:
     Example
     -------
     >>> image_question_answering(image, 'What is the cat doing ?')
-    'This image contains a cat sitting on a table with a bowl of milk.'
+    'drinking milk'
 
     """
 
@@ -328,7 +328,8 @@ def image_question_answering(image: np.ndarray, prompt: str) -> str:
         "tool": "image_question_answering",
     }
 
-    return _send_inference_request(data, "tools")["text"]
+    answer = _send_inference_request(data, "tools")
+    return answer["text"][0]
 
 
 def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
@@ -351,7 +352,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
 
     image_b64 = convert_to_b64(image)
     data = {
-        "prompt": classes,
+        "prompt": ",".join(classes),
         "image": image_b64,
         "tool": "closed_set_image_classification",
     }
@@ -360,6 +361,33 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
     return resp_data
 
 
+def image_caption(image: np.ndarray) -> str:
+    """'image_caption' is a tool that can caption an image based on its contents.
+    It returns a text describing the image.
+
+    Parameters:
+        image (np.ndarray): The image to caption
+
+    Returns:
+        str: A string which is the caption for the given image.
+
+    Example
+    -------
+    >>> image_caption(image)
+    'This image contains a cat sitting on a table with a bowl of milk.'
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "image_captioning",
+    }
+
+    answer = _send_inference_request(data, "tools")
+    return answer["text"][0]
+
+
 def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
     """'closest_mask_distance' calculates the closest distance between two masks.
 
@@ -633,6 +661,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
     zero_shot_counting,
     visual_prompt_counting,
     image_question_answering,
+    image_caption,
     closest_mask_distance,
     closest_box_distance,
     save_json,
diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py
index 1fb68a3f..5d638c62 100644
--- a/vision_agent/utils/image_utils.py
+++ b/vision_agent/utils/image_utils.py
@@ -104,15 +104,20 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     """
     if data is None:
        raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
+
    if isinstance(data, (str, Path)):
        data = Image.open(data)
+    elif isinstance(data, np.ndarray):
+        data = Image.fromarray(data)
+
    if isinstance(data, Image.Image):
        buffer = BytesIO()
        data.convert("RGB").save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    else:
-        arr_bytes = data.tobytes()
-        return base64.b64encode(arr_bytes).decode("utf-8")
+        raise ValueError(
+            f"Invalid input image: {data}. Input image must be a PIL Image or a numpy array."
+        )
 
 
 def denormalize_bbox(
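
A minimal usage sketch of the refactored functional tools this patch exercises (not part of the patch itself); it assumes scikit-image is installed and the "tools" inference endpoint is reachable, and the printed detection fields follow the docstring examples above:

    import skimage as ski

    from vision_agent.tools.tools_v2 import grounding_dino, image_caption

    # The tools now accept numpy arrays directly; convert_to_b64 wraps an
    # np.ndarray in a PIL image before base64-encoding it for the request.
    coins = ski.data.coins()
    for det in grounding_dino(prompt="coin", image=coins):
        print(det)  # one dict per detected coin

    # image_caption returns the first string in the service's "text" field.
    print(image_caption(ski.data.rocket()))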