From 10b3151a4d549f0a4d8a8e61cdc52d65e157f2d1 Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Wed, 28 Aug 2024 11:42:48 -0700 Subject: [PATCH 01/12] Adding countgd as default counting tool --- tests/integ/test_tools.py | 18 ++++ .../agent/vision_agent_coder_prompts.py | 9 +- vision_agent/lmm/lmm.py | 3 - vision_agent/tools/__init__.py | 2 + vision_agent/tools/tools.py | 98 ++++++++++++++++++- 5 files changed, 120 insertions(+), 10 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index afa9dcb4..eae02205 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -24,6 +24,8 @@ ixc25_video_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, + countgd_counting, + countgd_example_based_counting, ocr, owl_v2, template_match, @@ -186,6 +188,22 @@ def test_loca_visual_prompt_counting() -> None: assert result["count"] == 25 +def test_countgd_counting() -> None: + img = ski.data.coins() + + result = countgd_counting(image=img, prompt="coin") + assert result["count"] == 24 + + +def test_countgd_example_based_counting() -> None: + img = ski.data.coins() + result = countgd_example_based_counting( + visual_prompt=[[85, 106, 122, 145]], + image=img, + ) + assert result["count"] == 24 + + def test_git_vqa_v2() -> None: img = ski.data.rocket() result = git_vqa_v2( diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py index c68f73fe..b4c8a9bf 100644 --- a/vision_agent/agent/vision_agent_coder_prompts.py +++ b/vision_agent/agent/vision_agent_coder_prompts.py @@ -81,20 +81,19 @@ - Count the number of detected objects labeled as 'person'. plan3: - Load the image from the provided file path 'image.jpg'. -- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people. +- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people. 
```python -from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting +from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting image = load_image("image.jpg") owl_v2_out = owl_v2("person", image) gsam_out = grounding_sam("person", image) gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out] -loca_out = loca_zero_shot_counting(image) -loca_out = loca_out["count"] +cgd_out = countgd_counting(image) -final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}} +final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}} print(final_out) ``` """ diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index e78a0593..f0641958 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -276,9 +276,6 @@ def generate_segmentor(self, question: str) -> Callable: return lambda x: T.grounding_sam(params["prompt"], x) - def generate_zero_shot_counter(self, question: str) -> Callable: - return T.loca_zero_shot_counting - def generate_image_qa_tool(self, question: str) -> Callable: return lambda x: T.git_vqa_v2(question, x) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index a90b7181..e2e1b160 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -37,6 +37,8 @@ load_image, loca_visual_prompt_counting, loca_zero_shot_counting, + countgd_counting, + countgd_example_based_counting, ocr, overlay_bounding_boxes, overlay_heat_map, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 594fcf6d..3478e0a8 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -467,6 +467,8 @@ def loca_visual_prompt_counting( Parameters: image (np.ndarray): The image that contains lot of instances of a single object + visual_prompt (Dict[str, List[float]]): Bounding box of the object in format 
+ [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided. Returns: Dict[str, Any]: A dictionary containing the key 'count' and the count as a @@ -499,6 +501,99 @@ def loca_visual_prompt_counting( return resp_data +def countgd_counting( + prompt: str, + image: np.ndarray, + box_threshold: float = 0.23, +) -> List[Dict[str, Any]]: + """'countgd_counting' is a tool that can precisely count multiple instances of an + object given a text prompt. It returns a list of bounding boxes with normalized + coordinates, label names and associated confidence scores. + + Parameters: + prompt (str): The object that needs to be counted. + image (np.ndarray): The image that contains multiple instances of the object. + box_threshold (float, optional): The threshold for detection. Defaults + to 0.23. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, and + bounding box of the detected objects with normalized coordinates between 0 + and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the + top-left and xmax and ymax are the coordinates of the bottom-right of the + bounding box. 
+ + Example + ------- + >>> countgd_counting("flower", image) + [ + {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5}, + {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52}, + {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, + ] + """ + buffer_bytes = numpy_to_bytes(image) + files = [("image", buffer_bytes)] + payload = { + "text": prompt, + "visual_prompts": [], + "box_threshold": box_threshold, + "function_name": "countgd_counting", + } + data: Dict[str, Any] = send_inference_request( + payload, "countgd_counting", files=files, v2=True + ) + return data + + +def countgd_example_based_counting( + visual_prompts: List[List[float]], + image: np.ndarray, + box_threshold: float = 0.23, +) -> List[Dict[str, Any]]: + """'countgd_example_based_counting' is a tool that can precisely count multiple + instances of an object given few visual example prompts. It returns a list of bounding + boxes with normalized coordinates, label names and associated confidence scores. + + Parameters: + visual_prompts (List[List[float]]): Bounding boxes of the object in format + [xmin, ymin, xmax, ymax]. Upto 3 bounding boxes can be provided. + image (np.ndarray): The image that contains multiple instances of the object. + box_threshold (float, optional): The threshold for detection. Defaults + to 0.23. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the score, label, and + bounding box of the detected objects with normalized coordinates between 0 + and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the + top-left and xmax and ymax are the coordinates of the bottom-right of the + bounding box. 
+ + Example + ------- + >>> countgd_example_based_counting(visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]], image=image) + [ + {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5}, + {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52}, + {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58}, + ] + """ + buffer_bytes = numpy_to_bytes(image) + files = [("image", buffer_bytes)] + payload = { + "text": "", + "visual_prompts": visual_prompts, + "box_threshold": box_threshold, + "function_name": "countgd_example_based_counting", + } + data: Dict[str, Any] = send_inference_request( + payload, "countgd_example_based_counting", files=files, v2=True + ) + return data + + def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: """'florence2_roberta_vqa' is a tool that takes an image and analyzes its contents, generates detailed captions and then tries to answer the given @@ -1657,8 +1752,7 @@ def florencev2_fine_tuned_object_detection( clip, vit_image_classification, vit_nsfw_classification, - loca_zero_shot_counting, - loca_visual_prompt_counting, + countgd_counting, florence2_image_caption, florence2_ocr, florence2_sam2_image, From 8ee040e280edd046a01764182bc59749a9d42a14 Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Wed, 28 Aug 2024 11:46:09 -0700 Subject: [PATCH 02/12] fix mypy errors --- tests/integ/test_tools.py | 2 +- vision_agent/tools/tools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index eae02205..0308361d 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -198,7 +198,7 @@ def test_countgd_counting() -> None: def test_countgd_example_based_counting() -> None: img = ski.data.coins() result = countgd_example_based_counting( - visual_prompt=[[85, 106, 122, 145]], + visual_prompts=[[85, 106, 122, 145]], image=img, ) 
assert result["count"] == 24 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 3478e0a8..cfc43534 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -541,7 +541,7 @@ def countgd_counting( "box_threshold": box_threshold, "function_name": "countgd_counting", } - data: Dict[str, Any] = send_inference_request( + data: List[Dict[str, Any]] = send_inference_request( payload, "countgd_counting", files=files, v2=True ) return data @@ -588,7 +588,7 @@ def countgd_example_based_counting( "box_threshold": box_threshold, "function_name": "countgd_example_based_counting", } - data: Dict[str, Any] = send_inference_request( + data: List[Dict[str, Any]] = send_inference_request( payload, "countgd_example_based_counting", files=files, v2=True ) return data From cc84b6974e38f1efc11f195352efa5f7e054d0cc Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Wed, 28 Aug 2024 16:46:29 -0700 Subject: [PATCH 03/12] added viz for counting tool --- vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 69 +++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index e2e1b160..a50247b4 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -43,6 +43,7 @@ overlay_bounding_boxes, overlay_heat_map, overlay_segmentation_masks, + overlay_counting_results, owl_v2, save_image, save_json, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index cfc43534..819de37c 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -11,7 +11,7 @@ import numpy as np import requests from moviepy.editor import ImageSequenceClip -from PIL import Image, ImageDraw, ImageFont +from PIL import Image, ImageDraw, ImageFont, ImageEnhance from pillow_heif import register_heif_opener # type: ignore from pytube import YouTube # type: ignore @@ -1632,6 +1632,71 @@ def overlay_heat_map( 
return np.array(combined) +def overlay_counting_results( + image: np.ndarray, instances: List[Dict[str, Any]] +) -> np.ndarray: + """'overlay_counting_results' is a utility function that displays counting results on + an image. + + Parameters: + image (np.ndarray): The image to display the bounding boxes on. + instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding + box information of each instance + + Returns: + np.ndarray: The image with the instance_id dislpayed + + Example + ------- + >>> image_with_bboxes = overlay_counting_results( + image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}], + ) + """ + pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") + color = (158, 218, 229) + + width, height = pil_image.size + fontsize = max(10, int(min(width, height) / 80)) + pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5) + draw = ImageDraw.Draw(pil_image) + font = ImageFont.load_default(size=fontsize) + + for i, elt in enumerate(instances): + label = f"{i}" + box = elt["bbox"] + + # denormalize the box if it is normalized + box = denormalize_bbox(box, (height, width)) + x0, y0, x1, y1 = box + cx, cy = (x0 + x1) / 2, (y0 + y1) / 2 + + text_box = draw.textbbox( + (cx, cy), text=label, font=font, align="center", anchor="mm" + ) + + # Calculate the offset to center the text within the bounding box + text_width = text_box[2] - text_box[0] + text_height = text_box[3] - text_box[1] + text_x0 = cx - text_width / 2 + text_y0 = cy - text_height / 2 + text_x1 = cx + text_width / 2 + text_y1 = cy + text_height / 2 + + # Draw the rectangle encapsulating the text + draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color) + + # Draw the text at the center of the bounding box + draw.text( + (text_x0, text_y0), + label, + fill="black", + font=font, + anchor="lt", + ) + + return np.array(pil_image) + + # TODO: add this function to the imports so that is picked in the agent def florencev2_fine_tuning(bboxes: 
List[Dict[str, Any]], task: str) -> UUID: """'florencev2_fine_tuning' is a tool that fine-tune florencev2 to be able @@ -1775,6 +1840,7 @@ def florencev2_fine_tuned_object_detection( overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, + overlay_counting_results, ] TOOLS = FUNCTION_TOOLS + UTIL_TOOLS @@ -1792,5 +1858,6 @@ def florencev2_fine_tuned_object_detection( overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, + overlay_counting_results, ] ) From ba594e4afb076190f9485e560b279b2414be88ea Mon Sep 17 00:00:00 2001 From: Dayanne Fernandes Date: Fri, 30 Aug 2024 16:22:07 -0300 Subject: [PATCH 04/12] adjust call --- vision_agent/tools/tools.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 819de37c..c4f3b8df 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -533,16 +533,15 @@ def countgd_counting( {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ - buffer_bytes = numpy_to_bytes(image) - files = [("image", buffer_bytes)] + image_b64 = convert_to_b64(image) payload = { - "text": prompt, - "visual_prompts": [], + "image": image_b64, + "prompt": prompt, "box_threshold": box_threshold, - "function_name": "countgd_counting", } + metadata_payload = {"function_name": "countgd_counting"} data: List[Dict[str, Any]] = send_inference_request( - payload, "countgd_counting", files=files, v2=True + payload, "countgd", v2=True, metadata_payload=metadata_payload ) return data @@ -572,7 +571,10 @@ def countgd_example_based_counting( Example ------- - >>> countgd_example_based_counting(visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]], image=image) + >>> countgd_example_based_counting( + visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]], + image=image + ) [ {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}, {'score': 0.68, 'label': 'object', 
'bbox': [0.2, 0.21, 0.45, 0.5}, @@ -580,16 +582,15 @@ def countgd_example_based_counting( {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ - buffer_bytes = numpy_to_bytes(image) - files = [("image", buffer_bytes)] + image_b64 = convert_to_b64(image) payload = { - "text": "", + "image": image_b64, "visual_prompts": visual_prompts, "box_threshold": box_threshold, - "function_name": "countgd_example_based_counting", } + metadata_payload = {"function_name": "countgd_example_based_counting"} data: List[Dict[str, Any]] = send_inference_request( - payload, "countgd_example_based_counting", files=files, v2=True + payload, "countgd", v2=True, metadata_payload=metadata_payload ) return data From 7d64451c34cf076001511660926d6f53088f4afd Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Sun, 1 Sep 2024 06:39:53 -0700 Subject: [PATCH 05/12] fix bbox coords outside the image, countgd return types --- vision_agent/tools/tool_utils.py | 4 ++-- vision_agent/tools/tools.py | 5 ++++- vision_agent/utils/image_utils.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 185563a4..2a260c41 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,7 +1,7 @@ import inspect import logging import os -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union import pandas as pd from IPython.display import display @@ -34,7 +34,7 @@ def send_inference_request( files: Optional[List[Tuple[Any, ...]]] = None, v2: bool = False, metadata_payload: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: +) -> Union[Dict[str, Any], List[Dict[str, Any]]]: # TODO: runtime_tag and function_name should be metadata_payload and now included # in the service payload try: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 
c4f3b8df..a4eee6ac 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1660,7 +1660,10 @@ def overlay_counting_results( fontsize = max(10, int(min(width, height) / 80)) pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5) draw = ImageDraw.Draw(pil_image) - font = ImageFont.load_default(size=fontsize) + font = ImageFont.truetype( + str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), + fontsize, + ) for i, elt in enumerate(instances): label = f"{i}" diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py index d2bc8a6d..9c39be42 100644 --- a/vision_agent/utils/image_utils.py +++ b/vision_agent/utils/image_utils.py @@ -181,7 +181,7 @@ def denormalize_bbox( raise ValueError("Bounding box must be of length 4.") arr = np.array(bbox) - if np.all((arr >= 0) & (arr <= 1)): + if np.all((arr[:2] >= 0) & (arr[:2] <= 1)): x1, y1, x2, y2 = bbox x1 = round(x1 * image_size[1]) y1 = round(y1 * image_size[0]) From 18cce4fc31052ae5085abbccc414a6979468946f Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Sun, 1 Sep 2024 06:52:15 -0700 Subject: [PATCH 06/12] fix return values from countgd endpoint --- vision_agent/tools/tool_utils.py | 4 ++-- vision_agent/tools/tools.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 2a260c41..185563a4 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,7 +1,7 @@ import inspect import logging import os -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple import pandas as pd from IPython.display import display @@ -34,7 +34,7 @@ def send_inference_request( files: Optional[List[Tuple[Any, ...]]] = None, v2: bool = False, metadata_payload: Optional[Dict[str, Any]] = None, -) -> Union[Dict[str, Any], 
List[Dict[str, Any]]]: +) -> Dict[str, Any]: # TODO: runtime_tag and function_name should be metadata_payload and now included # in the service payload try: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index a4eee6ac..08bf0370 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -540,10 +540,10 @@ def countgd_counting( "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_counting"} - data: List[Dict[str, Any]] = send_inference_request( + resp: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload - ) - return data + ) # type: ignore + return resp["data"] def countgd_example_based_counting( @@ -589,10 +589,10 @@ def countgd_example_based_counting( "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_example_based_counting"} - data: List[Dict[str, Any]] = send_inference_request( + resp: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload - ) - return data + ) # type: ignore + return resp["data"] def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: From 7dd6a3777a654434881060c7d8e94cb620eb81a2 Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Sun, 1 Sep 2024 22:30:37 -0700 Subject: [PATCH 07/12] correct output format for cgd --- tests/integ/test_tools.py | 5 ++--- vision_agent/tools/tools.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 0308361d..cbc5eeb8 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -190,9 +190,8 @@ def test_loca_visual_prompt_counting() -> None: def test_countgd_counting() -> None: img = ski.data.coins() - result = countgd_counting(image=img, prompt="coin") - assert result["count"] == 24 + assert len(result) == 24 def test_countgd_example_based_counting() -> None: @@ -201,7 +200,7 @@ def 
test_countgd_example_based_counting() -> None: visual_prompts=[[85, 106, 122, 145]], image=img, ) - assert result["count"] == 24 + assert len(result) == 24 def test_git_vqa_v2() -> None: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 08bf0370..1ad6ea11 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -540,10 +540,11 @@ def countgd_counting( "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_counting"} - resp: List[Dict[str, Any]] = send_inference_request( + resp_data: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload ) # type: ignore - return resp["data"] + + return resp_data def countgd_example_based_counting( @@ -583,16 +584,20 @@ def countgd_example_based_counting( ] """ image_b64 = convert_to_b64(image) + visual_prompts = [ + denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts + ] payload = { "image": image_b64, "visual_prompts": visual_prompts, "box_threshold": box_threshold, } metadata_payload = {"function_name": "countgd_example_based_counting"} - resp: List[Dict[str, Any]] = send_inference_request( + resp_data: List[Dict[str, Any]] = send_inference_request( payload, "countgd", v2=True, metadata_payload=metadata_payload ) # type: ignore - return resp["data"] + + return resp_data def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: From d51344bcc2d6e046d4b0153d1eb363581d778d01 Mon Sep 17 00:00:00 2001 From: Dayanne Fernandes Date: Mon, 2 Sep 2024 20:43:46 -0300 Subject: [PATCH 08/12] adapt countgd + dev integration test --- .github/workflows/ci_cd.yml | 7 ++ tests/integration_dev/__init__.py | 0 tests/integration_dev/test_tools.py | 21 ++++ vision_agent/tools/tool_utils.py | 144 ++++++++++++++++++---------- vision_agent/tools/tools.py | 63 ++++++------ vision_agent/tools/tools_types.py | 19 +++- 6 files changed, 171 insertions(+), 83 deletions(-) create mode 100644 
tests/integration_dev/__init__.py create mode 100644 tests/integration_dev/test_tools.py diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index d41e7592..3576e10c 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -1,10 +1,14 @@ name: CI + on: push: branches: [ main ] pull_request: branches: [ main ] +env: + LANDINGAI_DEV_API_KEY: ${{ secrets.LANDINGAI_DEV_API_KEY }} + jobs: unit_test: name: Test @@ -79,6 +83,9 @@ jobs: - name: Test with pytest run: | poetry run pytest -v tests/integ + - name: Test with pytest, dev env + run: | + LANDINGAI_API_KEY=$LANDINGAI_DEV_API_KEY LANDINGAI_URL=https://api.dev.landing.ai poetry run pytest -v tests/integration_dev release: name: Release diff --git a/tests/integration_dev/__init__.py b/tests/integration_dev/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_dev/test_tools.py b/tests/integration_dev/test_tools.py new file mode 100644 index 00000000..29262245 --- /dev/null +++ b/tests/integration_dev/test_tools.py @@ -0,0 +1,21 @@ +import skimage as ski + +from vision_agent.tools import ( + countgd_counting, + countgd_example_based_counting, +) + + +def test_countgd_counting() -> None: + img = ski.data.coins() + result = countgd_counting(image=img, prompt="coin") + assert len(result) == 24 + + +def test_countgd_example_based_counting() -> None: + img = ski.data.coins() + result = countgd_example_based_counting( + visual_prompts=[[85, 106, 122, 145]], + image=img, + ) + assert len(result) == 24 diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 185563a4..30ac659b 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -1,6 +1,6 @@ +import os import inspect import logging -import os from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple import pandas as pd @@ -13,6 +13,7 @@ from vision_agent.utils.exceptions import RemoteToolCallFailed from 
vision_agent.utils.execute import Error, MimeType from vision_agent.utils.type_defs import LandingaiAPIKey +from vision_agent.tools.tools_types import BoundingBoxes _LOGGER = logging.getLogger(__name__) _LND_API_KEY = os.environ.get("LANDINGAI_API_KEY", LandingaiAPIKey().api_key) @@ -37,58 +38,55 @@ def send_inference_request( ) -> Dict[str, Any]: # TODO: runtime_tag and function_name should be metadata_payload and now included # in the service payload - try: - if runtime_tag := os.environ.get("RUNTIME_TAG", ""): - payload["runtime_tag"] = runtime_tag + if runtime_tag := os.environ.get("RUNTIME_TAG", ""): + payload["runtime_tag"] = runtime_tag + + url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}" + if "TOOL_ENDPOINT_URL" in os.environ: + url = os.environ["TOOL_ENDPOINT_URL"] + + headers = {"apikey": _LND_API_KEY} + if "TOOL_ENDPOINT_AUTH" in os.environ: + headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] + headers.pop("apikey") + + session = _create_requests_session( + url=url, + num_retry=3, + headers=headers, + ) - url = f"{_LND_API_URL_v2 if v2 else _LND_API_URL}/{endpoint_name}" - if "TOOL_ENDPOINT_URL" in os.environ: - url = os.environ["TOOL_ENDPOINT_URL"] + function_name = "unknown" + if "function_name" in payload: + function_name = payload["function_name"] + elif metadata_payload is not None and "function_name" in metadata_payload: + function_name = metadata_payload["function_name"] - tool_call_trace = ToolCallTrace( - endpoint_url=url, - request=payload, - response={}, - error=None, - ) - headers = {"apikey": _LND_API_KEY} - if "TOOL_ENDPOINT_AUTH" in os.environ: - headers["Authorization"] = os.environ["TOOL_ENDPOINT_AUTH"] - headers.pop("apikey") - - session = _create_requests_session( - url=url, - num_retry=3, - headers=headers, - ) + response = _call_post(url, payload, session, files, function_name) - if files is not None: - res = session.post(url, data=payload, files=files) - else: - res = session.post(url, json=payload) - if 
res.status_code != 200: - tool_call_trace.error = Error( - name="RemoteToolCallFailed", - value=f"{res.status_code} - {res.text}", - traceback_raw=[], - ) - _LOGGER.error(f"Request failed: {res.status_code} {res.text}") - # TODO: function_name should be in metadata_payload - function_name = "unknown" - if "function_name" in payload: - function_name = payload["function_name"] - elif metadata_payload is not None and "function_name" in metadata_payload: - function_name = metadata_payload["function_name"] - raise RemoteToolCallFailed(function_name, res.status_code, res.text) - - resp = res.json() - tool_call_trace.response = resp - # TODO: consider making the response schema the same between below two sources - return resp if "TOOL_ENDPOINT_AUTH" in os.environ else resp["data"] # type: ignore - finally: - trace = tool_call_trace.model_dump() - trace["type"] = "tool_call" - display({MimeType.APPLICATION_JSON: trace}, raw=True) + # TODO: consider making the response schema the same between below two sources + return response if "TOOL_ENDPOINT_AUTH" in os.environ else response["data"] + + +def send_task_inference_request( + payload: Dict[str, Any], + endpoint_name: str, + files: Optional[List[Tuple[Any, ...]]] = None, + metadata: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + url = f"{_LND_API_URL_v2}/{endpoint_name}" + headers = {"apikey": _LND_API_KEY} + session = _create_requests_session( + url=url, + num_retry=3, + headers=headers, + ) + + function_name = "unknown" + if metadata is not None and "function_name" in metadata: + function_name = metadata["function_name"] + response = _call_post(url, payload, session, files, function_name) + return response["data"] def _create_requests_session( @@ -195,3 +193,49 @@ def get_tools_info(funcs: List[Callable[..., Any]]) -> Dict[str, str]: data[func.__name__] = f"{func.__name__}{inspect.signature(func)}:\n{desc}" return data + + +def _call_post( + url: str, + payload: dict[str, Any], + session: Session, + files: 
Optional[List[Tuple[Any, ...]]] = None, + function_name: str = "unknown", +) -> dict[str, Any]: + try: + tool_call_trace = ToolCallTrace( + endpoint_url=url, + request=payload, + response={}, + error=None, + ) + + if files is not None: + response = session.post(url, data=payload, files=files) + else: + response = session.post(url, json=payload) + + if response.status_code != 200: + tool_call_trace.error = Error( + name="RemoteToolCallFailed", + value=f"{response.status_code} - {response.text}", + traceback_raw=[], + ) + _LOGGER.error(f"Request failed: {response.status_code} {response.text}") + raise RemoteToolCallFailed( + function_name, response.status_code, response.text + ) + + result = response.json() + tool_call_trace.response = result + return result + finally: + trace = tool_call_trace.model_dump() + trace["type"] = "tool_call" + display({MimeType.APPLICATION_JSON: trace}, raw=True) + + +def filter_bboxes_by_threshold( + bboxes: BoundingBoxes, threshold: float +) -> BoundingBoxes: + return list(map(lambda bbox: bbox["score"] >= threshold, bboxes)) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 1ad6ea11..add35f5c 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -22,6 +22,8 @@ get_tools_df, get_tools_info, send_inference_request, + send_task_inference_request, + filter_bboxes_by_threshold, ) from vision_agent.tools.tools_types import ( BboxInput, @@ -30,6 +32,7 @@ Florencev2FtRequest, JobStatus, PromptTask, + ODResponseData, ) from vision_agent.utils import extract_frames_from_video from vision_agent.utils.exceptions import FineTuneModelIsNotReady @@ -527,24 +530,22 @@ def countgd_counting( ------- >>> countgd_counting("flower", image) [ - {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]}, - {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5}, - {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52}, - {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 
0.49, 0.58}, + {'score': 0.49, 'label': 'flower', 'bounding_box': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.68, 'label': 'flower', 'bounding_box': [0.2, 0.21, 0.45, 0.5}, + {'score': 0.78, 'label': 'flower', 'bounding_box': [0.3, 0.35, 0.48, 0.52}, + {'score': 0.98, 'label': 'flower', 'bounding_box': [0.44, 0.24, 0.49, 0.58}, ] """ - image_b64 = convert_to_b64(image) - payload = { - "image": image_b64, - "prompt": prompt, - "box_threshold": box_threshold, - } - metadata_payload = {"function_name": "countgd_counting"} - resp_data: List[Dict[str, Any]] = send_inference_request( - payload, "countgd", v2=True, metadata_payload=metadata_payload - ) # type: ignore - - return resp_data + buffer_bytes = numpy_to_bytes(image) + files = [("image", buffer_bytes)] + payload = {"prompts": [prompt]} + metadata = {"function_name": "countgd_counting"} + resp_data: List[Dict[str, Any]] = send_task_inference_request( + payload, "text-to-object-detection", files=files, metadata=metadata + ) + bboxes_per_frame = resp_data[0] + bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] + return filter_bboxes_by_threshold(bboxes_formatted, box_threshold) def countgd_example_based_counting( @@ -577,27 +578,25 @@ def countgd_example_based_counting( image=image ) [ - {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}, - {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5}, - {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52}, - {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58}, + {'score': 0.49, 'label': 'object', 'bounding_box': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.68, 'label': 'object', 'bounding_box': [0.2, 0.21, 0.45, 0.5}, + {'score': 0.78, 'label': 'object', 'bounding_box': [0.3, 0.35, 0.48, 0.52}, + {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58}, ] """ - image_b64 = convert_to_b64(image) + buffer_bytes = numpy_to_bytes(image) + files = [("image", buffer_bytes)] visual_prompts = [ 
denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts ] - payload = { - "image": image_b64, - "visual_prompts": visual_prompts, - "box_threshold": box_threshold, - } - metadata_payload = {"function_name": "countgd_example_based_counting"} - resp_data: List[Dict[str, Any]] = send_inference_request( - payload, "countgd", v2=True, metadata_payload=metadata_payload - ) # type: ignore - - return resp_data + payload = {"visual_prompts": json.loads(visual_prompts)} + metadata = {"function_name": "countgd_example_based_counting"} + resp_data: List[Dict[str, Any]] = send_task_inference_request( + payload, "visual-prompts-to-object-detection", files=files, metadata=metadata + ) + bboxes_per_frame = resp_data[0] + bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] + return filter_bboxes_by_threshold(bboxes_formatted, box_threshold) def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py index aeb45c95..aa6f5f68 100644 --- a/vision_agent/tools/tools_types.py +++ b/vision_agent/tools/tools_types.py @@ -1,7 +1,8 @@ from uuid import UUID from enum import Enum -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Annotated +from annotated_types import Len from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo @@ -82,3 +83,19 @@ class JobStatus(str, Enum): SUCCEEDED = "SUCCEEDED" FAILED = "FAILED" STOPPED = "STOPPED" + + +BoundingBox = Annotated[list[int | float], Len(min_length=4, max_length=4)] + + +class ODResponseData(BaseModel): + label: str + score: float + bbox: BoundingBox = Field(alias="bounding_box") + + model_config = ConfigDict( + populate_by_name=True, + ) + + +BoundingBoxes = list[ODResponseData] From becc7f3f1c95aee507c3e17bc66fd9146d26faa0 Mon Sep 17 00:00:00 2001 From: Dayanne Fernandes Date: Mon, 2 Sep 2024 20:47:08 -0300 Subject: [PATCH 09/12] add model --- 
vision_agent/tools/tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index afd4f6bb..960f51f9 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -540,7 +540,7 @@ def countgd_counting( """ buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] - payload = {"prompts": [prompt]} + payload = {"prompts": [prompt], "model": "countgd"} metadata = {"function_name": "countgd_counting"} resp_data: List[Dict[str, Any]] = send_task_inference_request( payload, "text-to-object-detection", files=files, metadata=metadata @@ -591,7 +591,7 @@ def countgd_example_based_counting( visual_prompts = [ denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts ] - payload = {"visual_prompts": json.loads(visual_prompts)} + payload = {"visual_prompts": json.loads(visual_prompts), "model": "countgd"} metadata = {"function_name": "countgd_example_based_counting"} resp_data: List[Dict[str, Any]] = send_task_inference_request( payload, "visual-prompts-to-object-detection", files=files, metadata=metadata From f8e05ee2559604f0fa63ee66c009dc6ed43b12e9 Mon Sep 17 00:00:00 2001 From: Dayanne Fernandes Date: Mon, 2 Sep 2024 20:58:40 -0300 Subject: [PATCH 10/12] linter --- vision_agent/tools/tool_utils.py | 12 ++++++------ vision_agent/tools/tools.py | 24 ++++++++++++++---------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/vision_agent/tools/tool_utils.py b/vision_agent/tools/tool_utils.py index 30ac659b..a14443bd 100644 --- a/vision_agent/tools/tool_utils.py +++ b/vision_agent/tools/tool_utils.py @@ -35,7 +35,7 @@ def send_inference_request( files: Optional[List[Tuple[Any, ...]]] = None, v2: bool = False, metadata_payload: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: +) -> Any: # TODO: runtime_tag and function_name should be metadata_payload and now included # in the service payload if runtime_tag := os.environ.get("RUNTIME_TAG", ""): 
@@ -70,11 +70,11 @@ def send_inference_request( def send_task_inference_request( payload: Dict[str, Any], - endpoint_name: str, + task_name: str, files: Optional[List[Tuple[Any, ...]]] = None, metadata: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - url = f"{_LND_API_URL_v2}/{endpoint_name}" +) -> Any: + url = f"{_LND_API_URL_v2}/{task_name}" headers = {"apikey": _LND_API_KEY} session = _create_requests_session( url=url, @@ -201,7 +201,7 @@ def _call_post( session: Session, files: Optional[List[Tuple[Any, ...]]] = None, function_name: str = "unknown", -) -> dict[str, Any]: +) -> Any: try: tool_call_trace = ToolCallTrace( endpoint_url=url, @@ -238,4 +238,4 @@ def _call_post( def filter_bboxes_by_threshold( bboxes: BoundingBoxes, threshold: float ) -> BoundingBoxes: - return list(map(lambda bbox: bbox["score"] >= threshold, bboxes)) + return list(filter(lambda bbox: bbox.score >= threshold, bboxes)) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 960f51f9..e0961398 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -458,7 +458,7 @@ def loca_zero_shot_counting(image: np.ndarray) -> Dict[str, Any]: "image": image_b64, "function_name": "loca_zero_shot_counting", } - resp_data = send_inference_request(data, "loca", v2=True) + resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True) resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8) return resp_data @@ -501,7 +501,7 @@ def loca_visual_prompt_counting( "bbox": list(map(int, denormalize_bbox(bbox, image_size))), "function_name": "loca_visual_prompt_counting", } - resp_data = send_inference_request(data, "loca", v2=True) + resp_data: dict[str, Any] = send_inference_request(data, "loca", v2=True) resp_data["heat_map"] = np.array(resp_data["heat_map"][0]).astype(np.uint8) return resp_data @@ -542,12 +542,13 @@ def countgd_counting( files = [("image", buffer_bytes)] payload = {"prompts": [prompt], "model": "countgd"} 
metadata = {"function_name": "countgd_counting"} - resp_data: List[Dict[str, Any]] = send_task_inference_request( + resp_data = send_task_inference_request( payload, "text-to-object-detection", files=files, metadata=metadata ) bboxes_per_frame = resp_data[0] bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] - return filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + return [bbox.model_dump() for bbox in filtered_bboxes] def countgd_example_based_counting( @@ -591,14 +592,15 @@ def countgd_example_based_counting( visual_prompts = [ denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts ] - payload = {"visual_prompts": json.loads(visual_prompts), "model": "countgd"} + payload = {"visual_prompts": json.dumps(visual_prompts), "model": "countgd"} metadata = {"function_name": "countgd_example_based_counting"} - resp_data: List[Dict[str, Any]] = send_task_inference_request( + resp_data = send_task_inference_request( payload, "visual-prompts-to-object-detection", files=files, metadata=metadata ) bboxes_per_frame = resp_data[0] bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] - return filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) + return [bbox.model_dump() for bbox in filtered_bboxes] def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str: @@ -746,7 +748,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]: "tool": "closed_set_image_classification", "function_name": "clip", } - resp_data = send_inference_request(data, "tools") + resp_data: dict[str, Any] = send_inference_request(data, "tools") resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] return resp_data @@ -774,7 +776,7 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]: "tool": "image_classification", 
"function_name": "vit_image_classification", } - resp_data = send_inference_request(data, "tools") + resp_data: dict[str, Any] = send_inference_request(data, "tools") resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]] return resp_data @@ -801,7 +803,9 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]: "image": image_b64, "function_name": "vit_nsfw_classification", } - resp_data = send_inference_request(data, "nsfw-classification", v2=True) + resp_data: dict[str, Any] = send_inference_request( + data, "nsfw-classification", v2=True + ) resp_data["score"] = round(resp_data["score"], 4) return resp_data From 3630d871ac6b34eb38cf17f162c52bf703dbf547 Mon Sep 17 00:00:00 2001 From: Dayanne Fernandes Date: Mon, 2 Sep 2024 21:03:02 -0300 Subject: [PATCH 11/12] linter --- tests/integ/test_tools.py | 17 ----------------- vision_agent/tools/tools_types.py | 8 ++------ 2 files changed, 2 insertions(+), 23 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 2ebc1cab..bca1f6ea 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -24,8 +24,6 @@ ixc25_video_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, - countgd_counting, - countgd_example_based_counting, ocr, owl_v2, template_match, @@ -188,21 +186,6 @@ def test_loca_visual_prompt_counting() -> None: assert result["count"] == 25 -def test_countgd_counting() -> None: - img = ski.data.coins() - result = countgd_counting(image=img, prompt="coin") - assert len(result) == 24 - - -def test_countgd_example_based_counting() -> None: - img = ski.data.coins() - result = countgd_example_based_counting( - visual_prompts=[[85, 106, 122, 145]], - image=img, - ) - assert len(result) == 24 - - def test_git_vqa_v2() -> None: img = ski.data.rocket() result = git_vqa_v2( diff --git a/vision_agent/tools/tools_types.py b/vision_agent/tools/tools_types.py index 8f456eb6..af1e8ee9 100644 --- a/vision_agent/tools/tools_types.py +++ 
b/vision_agent/tools/tools_types.py @@ -1,8 +1,7 @@ from enum import Enum from uuid import UUID -from typing import List, Tuple, Optional, Annotated +from typing import List, Tuple, Optional, Union -from annotated_types import Len from pydantic import BaseModel, ConfigDict, Field, field_serializer, SerializationInfo @@ -85,13 +84,10 @@ class JobStatus(str, Enum): STOPPED = "STOPPED" -BoundingBox = Annotated[list[int | float], Len(min_length=4, max_length=4)] - - class ODResponseData(BaseModel): label: str score: float - bbox: BoundingBox = Field(alias="bounding_box") + bbox: Union[list[int], list[float]] = Field(alias="bounding_box") model_config = ConfigDict( populate_by_name=True, From b69fd5ec1812406ec941efebbb3d67cf6ab794c4 Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Tue, 3 Sep 2024 16:32:25 -0700 Subject: [PATCH 12/12] fixed keys in the example string, add support for multi-class --- vision_agent/tools/tools.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index e0961398..8012e60d 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -532,14 +532,15 @@ def countgd_counting( ------- >>> countgd_counting("flower", image) [ - {'score': 0.49, 'label': 'flower', 'bounding_box': [0.1, 0.11, 0.35, 0.4]}, - {'score': 0.68, 'label': 'flower', 'bounding_box': [0.2, 0.21, 0.45, 0.5}, - {'score': 0.78, 'label': 'flower', 'bounding_box': [0.3, 0.35, 0.48, 0.52}, - {'score': 0.98, 'label': 'flower', 'bounding_box': [0.44, 0.24, 0.49, 0.58}, + {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]}, + {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5}, + {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52}, + {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] + prompt = prompt.replace(", ", " .") payload = {"prompts": 
[prompt], "model": "countgd"} metadata = {"function_name": "countgd_counting"} resp_data = send_task_inference_request(