From 5f11aea79387e75eb59fad35bf500ef3dfbf7bd2 Mon Sep 17 00:00:00 2001
From: Shankar <90070882+shankar-landing-ai@users.noreply.github.com>
Date: Mon, 22 Apr 2024 14:09:48 -0700
Subject: [PATCH] Add Count tools (#56)

* Adding counting tools to vision agent

* fixed heatmap overlay and addressed PR comments

* updating the counting tool to take both absolute and normalized coordinates, refactoring code, and adding the LLM generate-counter tool

* fix linting
---
 README.md                          |   4 +-
 vision_agent/agent/vision_agent.py |  43 +++++---
 vision_agent/image_utils.py        | 101 +++++++++++++++++-
 vision_agent/llm/llm.py            |   4 +
 vision_agent/lmm/lmm.py            |   4 +
 vision_agent/tools/__init__.py     |   2 +
 vision_agent/tools/tools.py        | 163 +++++++++++++++++++++++------
 7 files changed, 272 insertions(+), 49 deletions(-)

diff --git a/README.md b/README.md
index 741b8ff2..c09e5823 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
 Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
 Many current vision problems can easily take hours or days to solve, you need to find the
-right model, figure out how to use it, possibly write programming logic around it to 
+right model, figure out how to use it, possibly write programming logic around it to
 accomplish the task you want or even more expensive, train your own model. Vision Agent
 aims to provide an in-seconds experience by allowing users to describe their problem in
 text and utilizing agent frameworks to solve the task for them. Check out our discord
@@ -108,6 +108,8 @@ you. For example:
 | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
 | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
 | ExtractFrames | ExtractFrames extracts frames with motion from a video. |
+| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
+| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class, given an image and a visual prompt. |
 
 It also has a basic set of calculate tools such as add, subtract, multiply and divide.
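For readers skimming the patch, here is a minimal usage sketch of the two new tools, based on the docstrings added in `vision_agent/tools/tools.py` below; the image path and returned counts are placeholders.

```python
import vision_agent as va

# Count instances of the dominant object class with no prompt at all.
zshot_count = va.tools.ZeroShotCounting()
print(zshot_count("lids.jpg"))  # e.g. {'count': 45}

# Count instances that match one example bounding box, given as a normalized
# "x1, y1, x2, y2" string.
prompt_count = va.tools.VisualPromptCounting()
print(prompt_count(image="lids.jpg", prompt="0.1, 0.1, 0.14, 0.2"))  # e.g. {'count': 23}
```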
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 3287a174..b02cdf72 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -8,7 +8,7 @@
 from PIL import Image
 from tabulate import tabulate
 
-from vision_agent.image_utils import overlay_bboxes, overlay_masks
+from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -336,7 +336,9 @@ def _handle_viz_tools(
 
     for param, call_result in zip(parameters, tool_result["call_results"]):
         # calls can fail, so we need to check if the call was successful
-        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+        if not isinstance(call_result, dict) or (
+            "bboxes" not in call_result and "masks" not in call_result
+        ):
             return image_to_data
 
         # if the call was successful, then we can add the image data
@@ -349,11 +351,12 @@ def _handle_viz_tools(
                 "scores": [],
             }
 
-        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-        image_to_data[image]["labels"].extend(call_result["labels"])
-        image_to_data[image]["scores"].extend(call_result["scores"])
-        if "masks" in call_result:
-            image_to_data[image]["masks"].extend(call_result["masks"])
+        image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
+        image_to_data[image]["labels"].extend(call_result.get("labels", []))
+        image_to_data[image]["scores"].extend(call_result.get("scores", []))
+        image_to_data[image]["masks"].extend(call_result.get("masks", []))
+        if "mask_shape" in call_result:
+            image_to_data[image]["mask_shape"] = call_result["mask_shape"]
 
     return image_to_data
 
@@ -367,6 +370,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
         "grounding_dino_",
         "extract_frames_",
         "dinov_",
+        "zero_shot_counting_",
+        "visual_prompt_counting_",
     ]:
         continue
 
@@ -379,8 +384,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
     for image_str in image_to_data:
         image_path = Path(image_str)
         image_data = image_to_data[image_str]
-        image = overlay_masks(image_path, image_data)
-        image = overlay_bboxes(image, image_data)
+        if "_counting_" in tool_result["tool_name"]:
+            image = overlay_heat_map(image_path, image_data)
+        else:
+            image = overlay_masks(image_path, image_data)
+            image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
             image.save(f.name)
             visualized_images.append(f.name)
@@ -484,11 +492,21 @@ def chat_with_workflow(
         if image:
             question += f" Image name: {image}"
         if reference_data:
-            if not ("image" in reference_data and "mask" in reference_data):
+            if not (
+                "image" in reference_data
+                and ("mask" in reference_data or "bbox" in reference_data)
+            ):
                 raise ValueError(
-                    f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
+                    f"Reference data must contain 'image' and a visual prompt, which can be 'mask' or 'bbox', but got {reference_data}"
                 )
-            question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
+            visual_prompt_data = (
+                f"Reference mask: {reference_data['mask']}"
+                if "mask" in reference_data
+                else f"Reference bbox: {reference_data['bbox']}"
+            )
+            question += (
+                f" Reference image: {reference_data['image']}, {visual_prompt_data}"
+            )
 
         reflections = ""
         final_answer = ""
@@ -531,7 +549,6 @@ def chat_with_workflow(
             final_answer = answer_summarize(
                 self.answer_model, question, answers, reflections
             )
-
             visualized_output = visualize_result(all_tool_results)
             all_tool_results.append({"visualized_output": visualized_output})
             if len(visualized_output) > 0:
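The agent-side changes above do two things: `chat_with_workflow` now accepts either a `mask` or a `bbox` as the visual prompt in `reference_data`, and `visualize_result` routes any tool whose name contains `_counting_` to the new heat-map overlay. A sketch of the new `reference_data` contract, assuming the `VisionAgent` construction and chat format used elsewhere in the repo (file names are placeholders):

```python
import vision_agent as va

agent = va.agent.VisionAgent()

# 'image' plus either a 'mask' or a 'bbox' is now accepted; passing 'image'
# alone raises the ValueError added above.
agent.chat_with_workflow(
    [{"role": "user", "content": "Count the shirts that look like this one."}],
    image="shirts.jpg",
    reference_data={"image": "shirt_example.jpg", "bbox": [0.1, 0.15, 0.2, 0.2]},
)
```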
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index f36a2033..43da645f 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -4,7 +4,7 @@
 from importlib import resources
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, Tuple, Union
+from typing import Dict, Tuple, Union, List
 
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
@@ -34,6 +34,35 @@
 ]
 
 
+def normalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""Normalize the bounding box coordinates to be between 0 and 1."""
+    x1, y1, x2, y2 = bbox
+    x1 = round(x1 / image_size[1], 2)
+    y1 = round(y1 / image_size[0], 2)
+    x2 = round(x2 / image_size[1], 2)
+    y2 = round(y2 / image_size[0], 2)
+    return [x1, y1, x2, y2]
+
+
+def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
+    r"""Decode a run-length encoded mask. Returns a numpy array where 1 is the mask and 0 is the background.
+
+    Parameters:
+        mask_rle: Run-length encoding as a string of space-separated (start, length) pairs
+        shape: The (height, width) of the array to return
+    """
+    s = mask_rle.split()
+    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
+    starts -= 1
+    ends = starts + lengths
+    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
+    for lo, hi in zip(starts, ends):
+        img[lo:hi] = 1
+    return img.reshape(shape)
+
+
 def b64_to_pil(b64_str: str) -> ImageType:
     r"""Convert a base64 string to a PIL Image.
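A worked example of the two helpers moved into `image_utils.py` (values are illustrative):

```python
import numpy as np

from vision_agent.image_utils import normalize_bbox, rle_decode

# "1 3 10 2" marks a run of 3 starting at (1-indexed) pixel 1 and a run of 2
# starting at pixel 10, in row-major order.
mask = rle_decode("1 3 10 2", shape=(4, 4))
# array([[1, 1, 1, 0],
#        [0, 0, 0, 0],
#        [0, 1, 1, 0],
#        [0, 0, 0, 0]], dtype=uint8)

# normalize_bbox maps absolute pixel coords to [0, 1]; image_size is (height, width).
print(normalize_bbox([64, 32, 128, 96], image_size=(256, 256)))  # [0.25, 0.12, 0.5, 0.38]
```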
@@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     return base64.b64encode(arr_bytes).decode("utf-8")
 
 
+def denormalize_bbox(
+    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
+) -> List[float]:
+    r"""Denormalize the bounding box coordinates so that they are in absolute values."""
+
+    if len(bbox) != 4:
+        raise ValueError("Bounding box must be of length 4.")
+
+    arr = np.array(bbox)
+    if np.all((arr >= 0) & (arr <= 1)):
+        x1, y1, x2, y2 = bbox
+        x1 = round(x1 * image_size[1])
+        y1 = round(y1 * image_size[0])
+        x2 = round(x2 * image_size[1])
+        y2 = round(y2 * image_size[0])
+        return [x1, y1, x2, y2]
+    else:
+        return bbox
+
+
 def overlay_bboxes(
     image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
 ) -> ImageType:
@@ -103,6 +152,9 @@ def overlay_bboxes(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
     }
@@ -114,8 +166,6 @@ def overlay_bboxes(
         str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
         fontsize,
     )
-    if "bboxes" not in bboxes:
-        return image.convert("RGB")
 
     for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
         box = [
@@ -150,11 +200,15 @@ def overlay_masks(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    if "labels" not in masks:
+        masks["labels"] = [""] * len(masks["masks"])
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
     }
-    if "masks" not in masks:
-        return image.convert("RGB")
 
     for label, mask in zip(masks["labels"], masks["masks"]):
         if isinstance(mask, str):
@@ -164,3 +218,40 @@ def overlay_masks(
         mask_img = Image.fromarray(np_mask.astype(np.uint8))
         image = Image.alpha_composite(image.convert("RGBA"), mask_img)
     return image.convert("RGB")
+
+
+def overlay_heat_map(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
+) -> ImageType:
+    r"""Plots a heat map on top of an image.
+
+    Parameters:
+        image: the input image
+        masks: the heat map to overlay
+        alpha: the transparency of the overlay
+
+    Returns:
+        The image with the heat map overlaid
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    # Only one heat map per image, so no need to loop through masks
+    image = image.convert("L")
+
+    if isinstance(masks["masks"][0], str):
+        mask = b64_to_pil(masks["masks"][0])
+
+    overlay = Image.new("RGBA", mask.size)
+    odraw = ImageDraw.Draw(overlay)
+    odraw.bitmap(
+        (0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
+    )  # fill=(R, G, B, Alpha)
+    combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))
+
+    return combined.convert("RGB")
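`denormalize_bbox` is the inverse heuristic of `normalize_bbox`: coordinates that all lie in [0, 1] are scaled up to pixels, anything else is assumed to already be absolute and passed through unchanged. A small sketch:

```python
from vision_agent.image_utils import denormalize_bbox

# image_size is (height, width); x scales by width, y by height.
print(denormalize_bbox([0.1, 0.1, 0.4, 0.42], image_size=(480, 640)))
# [64, 48, 256, 202]
print(denormalize_bbox([64, 48, 256, 202], image_size=(480, 640)))
# unchanged: [64, 48, 256, 202]
```

One consequence of the heuristic worth noting: an absolute box whose coordinates all happen to fall within [0, 1] would be misread as normalized.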
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 9352f58b..3f83c269 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -11,6 +11,7 @@
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
+    ZeroShotCounting,
 )
 
 
@@ -127,6 +128,9 @@ def generate_segmentor(self, question: str) -> Callable:
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
+    def generate_zero_shot_counter(self, question: str) -> Callable:
+        return lambda x: ZeroShotCounting()(**{"image": x})
+
 
 class AzureOpenAILLM(OpenAILLM):
     def __init__(
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 615804ed..06ce94a2 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -15,6 +15,7 @@
     SYSTEM_PROMPT,
     GroundingDINO,
     GroundingSAM,
+    ZeroShotCounting,
 )
 
 _LOGGER = logging.getLogger(__name__)
@@ -272,6 +273,9 @@ def generate_segmentor(self, question: str) -> Callable:
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
+    def generate_zero_shot_counter(self, question: str) -> Callable:
+        return lambda x: ZeroShotCounting()(**{"image": x})
+
 
 class AzureOpenAILMM(OpenAILMM):
     def __init__(
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 1c1c6e73..38bb08d4 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -11,6 +11,8 @@
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
+    ZeroShotCounting,
+    VisualPromptCounting,
     SegArea,
     SegIoU,
     Tool,
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 7657a362..c964a895 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -9,7 +9,13 @@
 from PIL import Image
 from PIL.Image import Image as ImageType
 
-from vision_agent.image_utils import convert_to_b64, get_image_size
+from vision_agent.image_utils import (
+    convert_to_b64,
+    get_image_size,
+    rle_decode,
+    normalize_bbox,
+    denormalize_bbox,
+)
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
 
@@ -18,35 +24,6 @@
 _LND_API_URL = "https://api.dev.landing.ai/v1/agent"
 
 
-def normalize_bbox(
-    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
-) -> List[float]:
-    r"""Normalize the bounding box coordinates to be between 0 and 1."""
-    x1, y1, x2, y2 = bbox
-    x1 = round(x1 / image_size[1], 2)
-    y1 = round(y1 / image_size[0], 2)
-    x2 = round(x2 / image_size[1], 2)
-    y2 = round(y2 / image_size[0], 2)
-    return [x1, y1, x2, y2]
-
-
-def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
-    r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
-
-    Parameters:
-        mask_rle: Run-length as string formated (start length)
-        shape: The (height, width) of array to return
-    """
-    s = mask_rle.split()
-    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
-    starts -= 1
-    ends = starts + lengths
-    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
-    for lo, hi in zip(starts, ends):
-        img[lo:hi] = 1
-    return img.reshape(shape)
-
-
 class Tool(ABC):
     name: str
     description: str
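The new `generate_zero_shot_counter` helpers on `OpenAILLM` and `OpenAILMM` mirror `generate_segmentor`, except the question text is ignored: unlike the segmentor, `ZeroShotCounting` takes no prompt, so there is nothing to extract from the question. A sketch (requires an OpenAI key and the hosted tools endpoint; the image path is a placeholder):

```python
from vision_agent.llm import OpenAILLM

llm = OpenAILLM()
# Returns a callable that runs ZeroShotCounting on whatever image it is given.
counter = llm.generate_zero_shot_counter("count the objects in my image")
print(counter("lids.jpg"))  # e.g. {'count': 45}
```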
@@ -489,6 +466,130 @@ def __call__(
 
         return rets
 
 
+class ZeroShotCounting(Tool):
+    r"""ZeroShotCounting is a tool that can count the total number of instances of an
+    object present in an image, where all instances belong to the same class, without
+    a text or visual prompt.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> zshot_count = va.tools.ZeroShotCounting()
+        >>> zshot_count("image1.jpg")
+        {'count': 45}
+    """
+
+    name = "zero_shot_counting_"
+    description = "'zero_shot_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class, without a text or visual prompt."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Can you count the lids in the image? Image name: lids.jpg",
+                "parameters": {"image": "lids.jpg"},
+            },
+            {
+                "scenario": "Can you count the total number of objects in this image? Image name: tray.jpg",
+                "parameters": {"image": "tray.jpg"},
+            },
+            {
+                "scenario": "Can you build me an object counting tool? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for inputting multiple images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType]) -> Dict:
+        """Invoke the zero-shot counting model.
+
+        Parameters:
+            image: the input image.
+
+        Returns:
+            A dictionary containing the key 'count' and the count as its value,
+            e.g. {'count': 12}
+        """
+        image_b64 = convert_to_b64(image)
+        data = {
+            "image": image_b64,
+            "tool": "zero_shot_counting",
+        }
+        return _send_inference_request(data, "tools")
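For reference, `ZeroShotCounting.__call__` reduces to a single request against the hosted tools endpoint. A hedged sketch of the payload it builds; the response schema beyond the 'count' key is not shown in this patch, and the image path is a placeholder:

```python
from vision_agent.image_utils import convert_to_b64

# _send_inference_request(data, "tools") POSTs this and returns e.g. {'count': 45}.
data = {
    "image": convert_to_b64("lids.jpg"),
    "tool": "zero_shot_counting",
}
```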
+
+
+class VisualPromptCounting(Tool):
+    r"""VisualPromptCounting is a tool that can count the total number of instances of an
+    object present in an image, where all instances belong to the same class, with the
+    help of a visual prompt in the form of a single example bounding box.
+
+    Example
+    -------
+        >>> import vision_agent as va
+        >>> prompt_count = va.tools.VisualPromptCounting()
+        >>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
+        {'count': 23}
+    """
+
+    name = "visual_prompt_counting_"
+    description = "'visual_prompt_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class, given an example bounding box."
+
+    usage = {
+        "required_parameters": [
+            {"name": "image", "type": "str"},
+            {"name": "prompt", "type": "str"},
+        ],
+        "examples": [
+            {
+                "scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', can you count the lids in the image? Image name: lids.jpg",
+                "parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
+            },
+            {
+                "scenario": "Can you count the total number of objects in this image? Image name: tray.jpg",
+                "parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
+            },
+            {
+                "scenario": "Can you build me a few-shot object counting tool? Image name: shirts.jpg",
+                "parameters": {
+                    "image": "shirts.jpg",
+                    "prompt": "0.1, 0.15, 0.2, 0.2",
+                },
+            },
+            {
+                "scenario": "Can you build me a counting tool based on an example prompt? Image name: shoes.jpg",
+                "parameters": {
+                    "image": "shoes.jpg",
+                    "prompt": "0.1, 0.1, 0.6, 0.65",
+                },
+            },
+        ],
+    }
+
+    # TODO: Add support for inputting multiple images, which aligns with the output type.
+    def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
+        """Invoke the visual prompt counting model.
+
+        Parameters:
+            image: the input image.
+            prompt: the visual prompt, an example bounding box given as normalized
+                coordinates "x1, y1, x2, y2".
+
+        Returns:
+            A dictionary containing the key 'count' and the count as its value,
+            e.g. {'count': 12}
+        """
+        image_size = get_image_size(image)
+        bbox = [float(x) for x in prompt.split(",")]
+        prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
+        image_b64 = convert_to_b64(image)
+
+        data = {
+            "image": image_b64,
+            "prompt": prompt,
+            "tool": "few_shot_counting",
+        }
+        return _send_inference_request(data, "tools")
+
+
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
@@ -798,6 +899,8 @@ def __call__(self, equation: str) -> float:
         ImageCaption,
         GroundingDINO,
         AgentGroundingSAM,
+        ZeroShotCounting,
+        VisualPromptCounting,
         AgentDINOv,
         ExtractFrames,
         Crop,
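`VisualPromptCounting` converts its prompt to absolute pixels before sending: a normalized bounding box goes through `denormalize_bbox`, while an already-absolute prompt passes through untouched. Note also that the server-side tool identifier it sends is "few_shot_counting", not the class name. A worked example of the coordinate round trip:

```python
from vision_agent.image_utils import denormalize_bbox

image_size = (480, 640)  # (height, width), as returned by get_image_size
bbox = [float(x) for x in "0.1, 0.1, 0.4, 0.42".split(",")]
print(", ".join(map(str, denormalize_bbox(bbox, image_size))))  # "64, 48, 256, 202"
```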