From 60576828943e6cacb0af7e96435352f626233be3 Mon Sep 17 00:00:00 2001 From: shankar_ws3 Date: Mon, 22 Apr 2024 13:35:26 -0700 Subject: [PATCH] adding the counting tool to take both absolute coordinate and normalized coordinates, refactoring code, adding llm generate counter tool --- README.md | 5 +++- vision_agent/image_utils.py | 51 ++++++++++++++++++++++++++++++++++- vision_agent/llm/llm.py | 4 +++ vision_agent/lmm/lmm.py | 5 ++++ vision_agent/tools/tools.py | 53 ++++++++++++------------------------- 5 files changed, 80 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 741b8ff2..c09e5823 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Vision Agent is a library that helps you utilize agent frameworks for your vision tasks. Many current vision problems can easily take hours or days to solve, you need to find the -right model, figure out how to use it, possibly write programming logic around it to +right model, figure out how to use it, possibly write programming logic around it to accomplish the task you want or even more expensive, train your own model. Vision Agent aims to provide an in-seconds experience by allowing users to describe their problem in text and utilizing agent frameworks to solve the task for them. Check out our discord @@ -108,6 +108,9 @@ you. For example: | BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. | | SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. | | ExtractFrames | ExtractFrames extracts frames with motion from a video. | +| ExtractFrames | ExtractFrames extracts frames with motion from a video. | +| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image | +| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt | It also has a basic set of calculate tools such as add, subtract, multiply and divide. diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index fefa6b13..43da645f 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -4,7 +4,7 @@ from importlib import resources from io import BytesIO from pathlib import Path -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, Union, List import numpy as np from PIL import Image, ImageDraw, ImageFont @@ -34,6 +34,35 @@ ] +def normalize_bbox( + bbox: List[Union[int, float]], image_size: Tuple[int, ...] +) -> List[float]: + r"""Normalize the bounding box coordinates to be between 0 and 1.""" + x1, y1, x2, y2 = bbox + x1 = round(x1 / image_size[1], 2) + y1 = round(y1 / image_size[0], 2) + x2 = round(x2 / image_size[1], 2) + y2 = round(y2 / image_size[0], 2) + return [x1, y1, x2, y2] + + +def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray: + r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background. + + Parameters: + mask_rle: Run-length as string formated (start length) + shape: The (height, width) of array to return + """ + s = mask_rle.split() + starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] + starts -= 1 + ends = starts + lengths + img = np.zeros(shape[0] * shape[1], dtype=np.uint8) + for lo, hi in zip(starts, ends): + img[lo:hi] = 1 + return img.reshape(shape) + + def b64_to_pil(b64_str: str) -> ImageType: r"""Convert a base64 string to a PIL Image. @@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: return base64.b64encode(arr_bytes).decode("utf-8") +def denormalize_bbox( + bbox: List[Union[int, float]], image_size: Tuple[int, ...] +) -> List[float]: + r"""DeNormalize the bounding box coordinates so that they are in absolute values.""" + + if len(bbox) != 4: + raise ValueError("Bounding box must be of length 4.") + + arr = np.array(bbox) + if np.all((arr >= 0) & (arr <= 1)): + x1, y1, x2, y2 = bbox + x1 = round(x1 * image_size[1]) + y1 = round(y1 * image_size[0]) + x2 = round(x2 * image_size[1]) + y2 = round(y2 * image_size[0]) + return [x1, y1, x2, y2] + else: + return bbox + + def overlay_bboxes( image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict ) -> ImageType: diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index 9352f58b..3f83c269 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -11,6 +11,7 @@ SYSTEM_PROMPT, GroundingDINO, GroundingSAM, + ZeroShotCounting, ) @@ -127,6 +128,9 @@ def generate_segmentor(self, question: str) -> Callable: return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x}) + def generate_zero_shot_counter(self, question: str) -> Callable: + return lambda x: ZeroShotCounting()(**{"image": x}) + class AzureOpenAILLM(OpenAILLM): def __init__( diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 615804ed..3d696cc3 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -15,6 +15,8 @@ SYSTEM_PROMPT, GroundingDINO, GroundingSAM, + ZeroShotCounting, + VisualPromptCounting, ) _LOGGER = logging.getLogger(__name__) @@ -272,6 +274,9 @@ def generate_segmentor(self, question: str) -> Callable: return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x}) + def generate_zero_shot_counter(self, question: str) -> Callable: + return lambda x: ZeroShotCounting()(**{"image": x}) + class AzureOpenAILMM(OpenAILMM): def __init__( diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index e4d8ada7..22a77443 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -9,7 +9,13 @@ from PIL import Image from PIL.Image import Image as ImageType -from vision_agent.image_utils import convert_to_b64, get_image_size +from vision_agent.image_utils import ( + convert_to_b64, + get_image_size, + rle_decode, + normalize_bbox, + denormalize_bbox, +) from vision_agent.tools.video import extract_frames_from_video from vision_agent.type_defs import LandingaiAPIKey @@ -18,35 +24,6 @@ _LND_API_URL = "https://api.dev.landing.ai/v1/agent" -def normalize_bbox( - bbox: List[Union[int, float]], image_size: Tuple[int, ...] -) -> List[float]: - r"""Normalize the bounding box coordinates to be between 0 and 1.""" - x1, y1, x2, y2 = bbox - x1 = round(x1 / image_size[1], 2) - y1 = round(y1 / image_size[0], 2) - x2 = round(x2 / image_size[1], 2) - y2 = round(y2 / image_size[0], 2) - return [x1, y1, x2, y2] - - -def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray: - r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background. - - Parameters: - mask_rle: Run-length as string formated (start length) - shape: The (height, width) of array to return - """ - s = mask_rle.split() - starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] - starts -= 1 - ends = starts + lengths - img = np.zeros(shape[0] * shape[1], dtype=np.uint8) - for lo, hi in zip(starts, ends): - img[lo:hi] = 1 - return img.reshape(shape) - - class Tool(ABC): name: str description: str @@ -556,7 +533,7 @@ class VisualPromptCounting(Tool): ------- >>> import vision_agent as va >>> prompt_count = va.tools.VisualPromptCounting() - >>> prompt_count(image="image1.jpg", prompt="100, 100, 200, 250") + >>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42") {'count': 23} """ @@ -570,25 +547,25 @@ class VisualPromptCounting(Tool): ], "examples": [ { - "scenario": "Here is an example of a lid '200, 200, 250, 300', Can you count the lids in the image ? Image name: lids.jpg", - "parameters": {"image": "lids.jpg", "prompt": "200, 200, 250, 300"}, + "scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the lids in the image ? Image name: lids.jpg", + "parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"}, }, { "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg", - "parameters": {"image": "tray.jpg", "prompt": "100, 100, 200, 250"}, + "parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"}, }, { "scenario": "Can you build me a few shot object counting tool ? Image name: shirts.jpg", "parameters": { "image": "shirts.jpg", - "prompt": "100, 100, 200, 250", + "prompt": "0.1, 0.15, 0.2, 0.2", }, }, { "scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg", "parameters": { "image": "shoes.jpg", - "prompt": "150, 100, 500, 550", + "prompt": "0.1, 0.1, 0.6, 0.65", }, }, ], @@ -604,7 +581,11 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict: Returns: A dictionary containing the key 'count' and the count as value. E.g. {count: 12} """ + image_size = get_image_size(image) + bbox = [float(x) for x in prompt.split(",")] + prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size))) image_b64 = convert_to_b64(image) + data = { "image": image_b64, "prompt": prompt,