From 10b3151a4d549f0a4d8a8e61cdc52d65e157f2d1 Mon Sep 17 00:00:00 2001
From: shankar-landing-ai
Date: Wed, 28 Aug 2024 11:42:48 -0700
Subject: [PATCH] Adding countgd as default counting tool

---
 tests/integ/test_tools.py                   | 18 ++++
 .../agent/vision_agent_coder_prompts.py     |  9 +-
 vision_agent/lmm/lmm.py                     |  3 -
 vision_agent/tools/__init__.py              |  2 +
 vision_agent/tools/tools.py                 | 98 ++++++++++++++++++-
 5 files changed, 120 insertions(+), 10 deletions(-)

diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py
index afa9dcb4..eae02205 100644
--- a/tests/integ/test_tools.py
+++ b/tests/integ/test_tools.py
@@ -24,6 +24,8 @@
     ixc25_video_vqa,
     loca_visual_prompt_counting,
     loca_zero_shot_counting,
+    countgd_counting,
+    countgd_example_based_counting,
     ocr,
     owl_v2,
     template_match,
@@ -186,6 +188,22 @@ def test_loca_visual_prompt_counting() -> None:
     assert result["count"] == 25
 
 
+def test_countgd_counting() -> None:
+    img = ski.data.coins()
+
+    result = countgd_counting(image=img, prompt="coin")
+    assert len(result) == 24
+
+
+def test_countgd_example_based_counting() -> None:
+    img = ski.data.coins()
+    result = countgd_example_based_counting(
+        visual_prompts=[[85, 106, 122, 145]],
+        image=img,
+    )
+    assert len(result) == 24
+
+
 def test_git_vqa_v2() -> None:
     img = ski.data.rocket()
     result = git_vqa_v2(
diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py
index c68f73fe..b4c8a9bf 100644
--- a/vision_agent/agent/vision_agent_coder_prompts.py
+++ b/vision_agent/agent/vision_agent_coder_prompts.py
@@ -81,20 +81,19 @@
 - Count the number of detected objects labeled as 'person'.
 plan3:
 - Load the image from the provided file path 'image.jpg'.
-- Use the 'loca_zero_shot_counting' tool to count the dominant foreground object, which in this case is people.
+- Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
 ```python
-from vision_agent.tools import load_image, owl_v2, grounding_sam, loca_zero_shot_counting
+from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
 
 image = load_image("image.jpg")
 owl_v2_out = owl_v2("person", image)
 
 gsam_out = grounding_sam("person", image)
 gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
 
-loca_out = loca_zero_shot_counting(image)
-loca_out = loca_out["count"]
+cgd_out = countgd_counting("person", image)
 
-final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "loca_zero_shot_counting": loca_out}}
+final_out = {{"owl_v2": owl_v2_out, "grounding_sam": gsam_out, "countgd_counting": cgd_out}}
 print(final_out)
 ```
 """
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index e78a0593..f0641958 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -276,9 +276,6 @@ def generate_segmentor(self, question: str) -> Callable:
 
         return lambda x: T.grounding_sam(params["prompt"], x)
 
-    def generate_zero_shot_counter(self, question: str) -> Callable:
-        return T.loca_zero_shot_counting
-
     def generate_image_qa_tool(self, question: str) -> Callable:
         return lambda x: T.git_vqa_v2(question, x)
 
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index a90b7181..e2e1b160 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -37,6 +37,8 @@
     load_image,
     loca_visual_prompt_counting,
     loca_zero_shot_counting,
+    countgd_counting,
+    countgd_example_based_counting,
     ocr,
     overlay_bounding_boxes,
     overlay_heat_map,
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 594fcf6d..3478e0a8 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -467,6 +467,8 @@ def loca_visual_prompt_counting(
     Parameters:
         image (np.ndarray): The image that contains lot of instances of a single
             object
+        visual_prompt (Dict[str, List[float]]): Bounding box of the object in format
+            [xmin, ymin, xmax, ymax]. Only 1 bounding box can be provided.
 
     Returns:
         Dict[str, Any]: A dictionary containing the key 'count' and the count as a
@@ -499,6 +501,99 @@ def loca_visual_prompt_counting(
     return resp_data
 
 
+def countgd_counting(
+    prompt: str,
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_counting' is a tool that can precisely count multiple instances of an
+    object given a text prompt. It returns a list of bounding boxes with normalized
+    coordinates, label names and associated confidence scores.
+
+    Parameters:
+        prompt (str): The object that needs to be counted.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+    >>> countgd_counting("flower", image)
+    [
+        {'score': 0.49, 'label': 'flower', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.68, 'label': 'flower', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+        {'score': 0.78, 'label': 'flower', 'bbox': [0.3, 0.35, 0.48, 0.52]},
+        {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
+    ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    payload = {
+        "text": prompt,
+        "visual_prompts": [],
+        "box_threshold": box_threshold,
+        "function_name": "countgd_counting",
+    }
+    data: List[Dict[str, Any]] = send_inference_request(
+        payload, "countgd_counting", files=files, v2=True
+    )
+    return data
+
+
+def countgd_example_based_counting(
+    visual_prompts: List[List[float]],
+    image: np.ndarray,
+    box_threshold: float = 0.23,
+) -> List[Dict[str, Any]]:
+    """'countgd_example_based_counting' is a tool that can precisely count multiple
+    instances of an object given a few visual example prompts. It returns a list of bounding
+    boxes with normalized coordinates, label names and associated confidence scores.
+
+    Parameters:
+        visual_prompts (List[List[float]]): Bounding boxes of the object in format
+            [xmin, ymin, xmax, ymax]. Up to 3 bounding boxes can be provided.
+        image (np.ndarray): The image that contains multiple instances of the object.
+        box_threshold (float, optional): The threshold for detection. Defaults
+            to 0.23.
+
+    Returns:
+        List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
+            bounding box of the detected objects with normalized coordinates between 0
+            and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the
+            top-left and xmax and ymax are the coordinates of the bottom-right of the
+            bounding box.
+
+    Example
+    -------
+    >>> countgd_example_based_counting(visual_prompts=[[0.1, 0.1, 0.4, 0.42], [0.2, 0.3, 0.25, 0.35]], image=image)
+    [
+        {'score': 0.49, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+        {'score': 0.68, 'label': 'object', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+        {'score': 0.78, 'label': 'object', 'bbox': [0.3, 0.35, 0.48, 0.52]},
+        {'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
+    ]
+    """
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    payload = {
+        "text": "",
+        "visual_prompts": visual_prompts,
+        "box_threshold": box_threshold,
+        "function_name": "countgd_example_based_counting",
+    }
+    data: List[Dict[str, Any]] = send_inference_request(
+        payload, "countgd_example_based_counting", files=files, v2=True
+    )
+    return data
+
+
 def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
     """'florence2_roberta_vqa' is a tool that takes an image and analyzes
     its contents, generates detailed captions and then tries to answer the given
@@ -1657,8 +1752,7 @@ def florencev2_fine_tuned_object_detection(
     clip,
     vit_image_classification,
     vit_nsfw_classification,
-    loca_zero_shot_counting,
-    loca_visual_prompt_counting,
+    countgd_counting,
     florence2_image_caption,
     florence2_ocr,
     florence2_sam2_image,
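For reference, a minimal usage sketch of the two tools added by this patch, mirroring the integration tests above. It assumes the vision-agent package is installed, scikit-image is available, and the hosted inference endpoint used by send_inference_request is reachable; the expected count of 24 comes from the tests rather than being a guaranteed model output.

```python
import skimage as ski

from vision_agent.tools import countgd_counting, countgd_example_based_counting

# Same sample image the integration tests use: a grayscale photo of many coins.
image = ski.data.coins()

# Text-prompted counting: each counted instance comes back as a dict with
# 'score', 'label' and a normalized 'bbox' [xmin, ymin, xmax, ymax].
detections = countgd_counting(prompt="coin", image=image)
print(len(detections))  # the integration test expects 24

# Example-based counting: pass up to 3 example boxes instead of a text prompt.
detections = countgd_example_based_counting(
    visual_prompts=[[85, 106, 122, 145]],
    image=image,
)
print(len(detections))  # the integration test expects 24
```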