From 21849f4617af7c61cc195c8ffe3f1adfb9069202 Mon Sep 17 00:00:00 2001
From: Shankar <90070882+shankar-vision-eng@users.noreply.github.com>
Date: Fri, 17 May 2024 11:49:30 -0700
Subject: [PATCH] Adding more cv tools to coder agent (#88)

* adding more cv tools to coder agent

* Bug fix: missing kernel python3 (#87)

* Bug fix: missing kernel python3

* Update API key

* [skip ci] chore(release): vision-agent 0.2.24

* adding more cv tools to coder agent

* Added test cases for the tools added

* added test case for every tool

* fix linting

* fixing tests

* fix linting

* fixing test cases for grounding tools as the output format changed

---------

Co-authored-by: Asia <2736300+humpydonkey@users.noreply.github.com>
Co-authored-by: GitHub Actions Bot
---
 tests/test_tools.py               |  82 ++++++++++----
 vision_agent/tools/tools.py       |   2 +-
 vision_agent/tools/tools_v2.py    | 180 ++++++++++++++++++++++++++++--
 vision_agent/utils/image_utils.py |   9 +-
 4 files changed, 241 insertions(+), 32 deletions(-)

diff --git a/tests/test_tools.py b/tests/test_tools.py
index b35cd7e3..2a848a02 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1,42 +1,84 @@
 import skimage as ski
-from PIL import Image
-from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption
+from vision_agent.tools.tools_v2 import (
+    clip,
+    zero_shot_counting,
+    visual_prompt_counting,
+    image_question_answering,
+    ocr,
+    grounding_dino,
+    grounding_sam,
+    image_caption,
+)
 
 
 def test_grounding_dino():
-    img = Image.fromarray(ski.data.coins())
-    result = GroundingDINO()(
+    img = ski.data.coins()
+    result = grounding_dino(
         prompt="coin",
         image=img,
     )
-    assert result["labels"] == ["coin"] * 24
-    assert len(result["bboxes"]) == 24
-    assert len(result["scores"]) == 24
+    assert len(result) == 24
+    assert [res["label"] for res in result] == ["coin"] * 24
 
 
 def test_grounding_sam():
-    img = Image.fromarray(ski.data.coins())
-    result = GroundingSAM()(
+    img = ski.data.coins()
+    result = grounding_sam(
         prompt="coin",
         image=img,
     )
-    assert result["labels"] == ["coin"] * 24
-    assert len(result["bboxes"]) == 24
-    assert len(result["scores"]) == 24
-    assert len(result["masks"]) == 24
+    assert len(result) == 24
+    assert [res["label"] for res in result] == ["coin"] * 24
+    assert len([res["mask"] for res in result]) == 24
 
 
 def test_clip():
-    img = Image.fromarray(ski.data.coins())
-    result = CLIP()(
-        prompt="coins",
+    img = ski.data.coins()
+    result = clip(
+        classes=["coins", "notes"],
         image=img,
     )
-    assert result["scores"] == [1.0]
+    assert result["scores"] == [0.9999, 0.0001]
 
 
 def test_image_caption() -> None:
-    img = Image.fromarray(ski.data.coins())
-    result = ImageCaption()(image=img)
-    assert result["text"]
+    img = ski.data.rocket()
+    result = image_caption(
+        image=img,
+    )
+    assert result.strip() == "a rocket on a stand"
+
+
+def test_zero_shot_counting() -> None:
+    img = ski.data.coins()
+    result = zero_shot_counting(
+        image=img,
+    )
+    assert result["count"] == 21
+
+
+def test_visual_prompt_counting() -> None:
+    img = ski.data.coins()
+    result = visual_prompt_counting(
+        visual_prompt={"bbox": [85, 106, 122, 145]},
+        image=img,
+    )
+    assert result["count"] == 25
+
+
+def test_image_question_answering() -> None:
+    img = ski.data.rocket()
+    result = image_question_answering(
+        prompt="Is the scene captured during day or night ?",
+        image=img,
+    )
+    assert result.strip() == "night"
+
+
+def test_ocr() -> None:
+    img = ski.data.page()
+    result = ocr(
+        image=img,
+    )
+    assert any("Region-based segmentation" in res["label"] for res in result)
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index cad7f4ad..fdbc1fe2 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -53,7 +53,7 @@ def __call__(self) -> None:
 
 
 class CLIP(Tool):
-    r"""CLIP is a tool that can classify or tag any image given a set if input classes
+    r"""CLIP is a tool that can classify or tag any image given a set of input classes
     or tags.
 
     Example
diff --git a/vision_agent/tools/tools_v2.py b/vision_agent/tools/tools_v2.py
index 1f1a3c6d..37e76a28 100644
--- a/vision_agent/tools/tools_v2.py
+++ b/vision_agent/tools/tools_v2.py
@@ -15,7 +15,14 @@
 from vision_agent.tools.tool_utils import _send_inference_request
 from vision_agent.utils import extract_frames_from_video
-from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox, rle_decode
+from vision_agent.utils.image_utils import (
+    convert_to_b64,
+    normalize_bbox,
+    rle_decode,
+    b64_to_pil,
+    get_image_size,
+    denormalize_bbox,
+)
 
 
 COLORS = [
     (158, 218, 229),
@@ -49,7 +56,7 @@ def grounding_dino(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.20,
-    iou_threshold: float = 0.75,
+    iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
     """'grounding_dino' is a tool that can detect and count objects given a text prompt
     such as category names or referring expressions. It returns a list and count of
@@ -61,12 +68,13 @@ def grounding_dino(
         box_threshold (float, optional): The threshold for the box detection. Defaults
             to 0.20.
         iou_threshold (float, optional): The threshold for the Intersection over Union
-            (IoU). Defaults to 0.75.
+            (IoU). Defaults to 0.20.
 
     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
             bounding box of the detected objects with normalized coordinates
-            (x1, y1, x2, y2).
+            (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
+            xmax and ymax are the coordinates of the bottom-right of the bounding box.
 
     Example
     -------
@@ -77,7 +85,7 @@
         ...
     ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(Image.fromarray(image))
+    image_b64 = convert_to_b64(image)
     request_data = {
         "prompt": prompt,
         "image": image_b64,
@@ -101,7 +109,7 @@ def grounding_sam(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.20,
-    iou_threshold: float = 0.75,
+    iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
     """'grounding_sam' is a tool that can detect and segment objects given a text
     prompt such as category names or referring expressions. It returns a list of
@@ -113,12 +121,15 @@ def grounding_sam(
         box_threshold (float, optional): The threshold for the box detection. Defaults
             to 0.20.
         iou_threshold (float, optional): The threshold for the Intersection over Union
-            (IoU). Defaults to 0.75.
+            (IoU). Defaults to 0.20.
 
     Returns:
         List[Dict[str, Any]]: A list of dictionaries containing the score, label,
             bounding box, and mask of the detected objects with normalized coordinates
-            (x1, y1, x2, y2).
+            (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
+            xmax and ymax are the coordinates of the bottom-right of the bounding box.
+            The mask is a binary 2D numpy array where 1 indicates the object and 0 indicates
+            the background.
 
     Example
     -------
@@ -137,7 +148,7 @@ def grounding_sam(
     ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(Image.fromarray(image))
+    image_b64 = convert_to_b64(image)
     request_data = {
         "prompt": prompt,
         "image": image_b64,
@@ -235,6 +246,152 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     return output
 
+
+def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
+    """'zero_shot_counting' is a tool that counts the dominant foreground object given an image and no other information about the content.
+    It returns only the count of the objects in the image.
+
+    Parameters:
+        image (np.ndarray): The image that contains many instances of a single object
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {'count': 12}.
+
+    Example
+    -------
+    >>> zero_shot_counting(image)
+    {'count': 45},
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "zero_shot_counting",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    return resp_data
+
+
+def visual_prompt_counting(
+    image: np.ndarray, visual_prompt: Dict[str, List[float]]
+) -> Dict[str, Any]:
+    """'visual_prompt_counting' is a tool that counts the dominant foreground object given an image and a visual prompt which is a bounding box describing the object.
+    It returns only the count of the objects in the image.
+
+    Parameters:
+        image (np.ndarray): The image that contains many instances of a single object
+        visual_prompt (Dict[str, List[float]]): A bounding box around one example instance of the object, e.g. {"bbox": [xmin, ymin, xmax, ymax]}
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {'count': 12}.
+
+    Example
+    -------
+    >>> visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
+    {'count': 45},
+
+    """
+
+    image_size = get_image_size(image)
+    bbox = visual_prompt["bbox"]
+    bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
+    image_b64 = convert_to_b64(image)
+
+    data = {
+        "image": image_b64,
+        "prompt": bbox_str,
+        "tool": "few_shot_counting",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+    return resp_data
+
+
+def image_question_answering(image: np.ndarray, prompt: str) -> str:
+    """'image_question_answering' is a tool that can answer questions about the visual contents of an image given a question and an image.
+    It returns an answer to the question.
+
+    Parameters:
+        image (np.ndarray): The reference image used for the question
+        prompt (str): The question about the image
+
+    Returns:
+        str: A string which is the answer to the given prompt. E.g. 'drinking milk'.
+
+    Example
+    -------
+    >>> image_question_answering(image, 'What is the cat doing ?')
+    'drinking milk'
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "prompt": prompt,
+        "tool": "image_question_answering",
+    }
+
+    answer = _send_inference_request(data, "tools")
+    return answer["text"][0]  # type: ignore
+
+
+def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
+    """'clip' is a tool that can classify an image given a list of input classes or tags.
+    It returns the same list of the input classes along with their probability scores based on image content.
+
+    Parameters:
+        image (np.ndarray): The image to classify or tag
+        classes (List[str]): The list of classes or tags that are associated with the image
+
+    Returns:
+        Dict[str, Any]: A dictionary containing the labels and scores. The 'labels' key holds the list of given classes and the 'scores' key holds a list of corresponding probability scores.
+
+    Example
+    -------
+    >>> clip(image, ['dog', 'cat', 'bird'])
+    {"labels": ["dog", "cat", "bird"], "scores": [0.68, 0.30, 0.02]},
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "prompt": ",".join(classes),
+        "image": image_b64,
+        "tool": "closed_set_image_classification",
+    }
+    resp_data = _send_inference_request(data, "tools")
+    resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
+    return resp_data
+
+
+def image_caption(image: np.ndarray) -> str:
+    """'image_caption' is a tool that can caption an image based on its contents.
+    It returns text describing the image.
+
+    Parameters:
+        image (np.ndarray): The image to caption
+
+    Returns:
+        str: A string which is the caption for the given image.
+
+    Example
+    -------
+    >>> image_caption(image)
+    'This image contains a cat sitting on a table with a bowl of milk.'
+
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "image_captioning",
+    }
+
+    answer = _send_inference_request(data, "tools")
+    return answer["text"][0]  # type: ignore
+
+
 def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
     """'closest_mask_distance' calculates the closest distance between two masks.
 
@@ -504,6 +661,11 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
     grounding_sam,
     extract_frames,
     ocr,
+    clip,
+    zero_shot_counting,
+    visual_prompt_counting,
+    image_question_answering,
+    image_caption,
     closest_mask_distance,
     closest_box_distance,
     save_json,
diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py
index 1fb68a3f..5d638c62 100644
--- a/vision_agent/utils/image_utils.py
+++ b/vision_agent/utils/image_utils.py
@@ -104,15 +104,20 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     """
     if data is None:
         raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
+
     if isinstance(data, (str, Path)):
         data = Image.open(data)
+    elif isinstance(data, np.ndarray):
+        data = Image.fromarray(data)
+
     if isinstance(data, Image.Image):
         buffer = BytesIO()
         data.convert("RGB").save(buffer, format="PNG")
         return base64.b64encode(buffer.getvalue()).decode("utf-8")
     else:
-        arr_bytes = data.tobytes()
-        return base64.b64encode(arr_bytes).decode("utf-8")
+        raise ValueError(
+            f"Invalid input image: {data}. Input image must be a PIL Image or a numpy array."
+        )
 
 
 def denormalize_bbox(
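
As a quick orientation to the tools_v2 API added above, here is a minimal usage sketch that mirrors the test cases in this patch. It assumes vision-agent is installed and that the remote inference endpoint used by _send_inference_request is already configured (for example with an API key, which is outside the scope of this diff).

import skimage as ski

from vision_agent.tools.tools_v2 import clip, grounding_dino, zero_shot_counting

# Sample grayscale coin image bundled with scikit-image, as used in the tests.
img = ski.data.coins()

# Text-prompted detection: each entry carries a score, a label, and a
# normalized (xmin, ymin, xmax, ymax) bounding box.
detections = grounding_dino(prompt="coin", image=img)
print(len(detections), detections[0]["label"])

# Closed-set classification against a user-supplied list of classes.
print(clip(image=img, classes=["coins", "notes"]))

# Count the dominant foreground object without any prompt.
print(zero_shot_counting(image=img)["count"])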