Skip to content

Commit

Permalink
adding the counting tool to take both absolute coordinate and normali…
Browse files Browse the repository at this point in the history
…zed coordinates, refactoring code, adding llm generate counter tool
  • Loading branch information
shankar-vision-eng committed Apr 22, 2024
1 parent dd198bc commit 6057682
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 38 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
Many current vision problems can easily take hours or days to solve, you need to find the
right model, figure out how to use it, possibly write programming logic around it to
right model, figure out how to use it, possibly write programming logic around it to
accomplish the task you want or even more expensive, train your own model. Vision Agent
aims to provide an in-seconds experience by allowing users to describe their problem in
text and utilizing agent frameworks to solve the task for them. Check out our discord
Expand Down Expand Up @@ -108,6 +108,9 @@ you. For example:
| BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
| SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image |
| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt |


It also has a basic set of calculate tools such as add, subtract, multiply and divide.
Expand Down
51 changes: 50 additions & 1 deletion vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from importlib import resources
from io import BytesIO
from pathlib import Path
from typing import Dict, Tuple, Union
from typing import Dict, Tuple, Union, List

import numpy as np
from PIL import Image, ImageDraw, ImageFont
Expand Down Expand Up @@ -34,6 +34,35 @@
]


def normalize_bbox(
bbox: List[Union[int, float]], image_size: Tuple[int, ...]
) -> List[float]:
r"""Normalize the bounding box coordinates to be between 0 and 1."""
x1, y1, x2, y2 = bbox
x1 = round(x1 / image_size[1], 2)
y1 = round(y1 / image_size[0], 2)
x2 = round(x2 / image_size[1], 2)
y2 = round(y2 / image_size[0], 2)
return [x1, y1, x2, y2]


def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
Parameters:
mask_rle: Run-length as string formated (start length)
shape: The (height, width) of array to return
"""
s = mask_rle.split()
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
starts -= 1
ends = starts + lengths
img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
for lo, hi in zip(starts, ends):
img[lo:hi] = 1
return img.reshape(shape)


def b64_to_pil(b64_str: str) -> ImageType:
r"""Convert a base64 string to a PIL Image.
Expand Down Expand Up @@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
return base64.b64encode(arr_bytes).decode("utf-8")


def denormalize_bbox(
bbox: List[Union[int, float]], image_size: Tuple[int, ...]
) -> List[float]:
r"""DeNormalize the bounding box coordinates so that they are in absolute values."""

if len(bbox) != 4:
raise ValueError("Bounding box must be of length 4.")

arr = np.array(bbox)
if np.all((arr >= 0) & (arr <= 1)):
x1, y1, x2, y2 = bbox
x1 = round(x1 * image_size[1])
y1 = round(y1 * image_size[0])
x2 = round(x2 * image_size[1])
y2 = round(y2 * image_size[0])
return [x1, y1, x2, y2]
else:
return bbox


def overlay_bboxes(
image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
) -> ImageType:
Expand Down
4 changes: 4 additions & 0 deletions vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ZeroShotCounting,
)


Expand Down Expand Up @@ -127,6 +128,9 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})

def generate_zero_shot_counter(self, question: str) -> Callable:
return lambda x: ZeroShotCounting()(**{"image": x})


class AzureOpenAILLM(OpenAILLM):
def __init__(
Expand Down
5 changes: 5 additions & 0 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ZeroShotCounting,
VisualPromptCounting,
)

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -272,6 +274,9 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})

def generate_zero_shot_counter(self, question: str) -> Callable:
return lambda x: ZeroShotCounting()(**{"image": x})


class AzureOpenAILMM(OpenAILMM):
def __init__(
Expand Down
53 changes: 17 additions & 36 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
from PIL import Image
from PIL.Image import Image as ImageType

from vision_agent.image_utils import convert_to_b64, get_image_size
from vision_agent.image_utils import (
convert_to_b64,
get_image_size,
rle_decode,
normalize_bbox,
denormalize_bbox,
)
from vision_agent.tools.video import extract_frames_from_video
from vision_agent.type_defs import LandingaiAPIKey

Expand All @@ -18,35 +24,6 @@
_LND_API_URL = "https://api.dev.landing.ai/v1/agent"


def normalize_bbox(
bbox: List[Union[int, float]], image_size: Tuple[int, ...]
) -> List[float]:
r"""Normalize the bounding box coordinates to be between 0 and 1."""
x1, y1, x2, y2 = bbox
x1 = round(x1 / image_size[1], 2)
y1 = round(y1 / image_size[0], 2)
x2 = round(x2 / image_size[1], 2)
y2 = round(y2 / image_size[0], 2)
return [x1, y1, x2, y2]


def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.
Parameters:
mask_rle: Run-length as string formated (start length)
shape: The (height, width) of array to return
"""
s = mask_rle.split()
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
starts -= 1
ends = starts + lengths
img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
for lo, hi in zip(starts, ends):
img[lo:hi] = 1
return img.reshape(shape)


class Tool(ABC):
name: str
description: str
Expand Down Expand Up @@ -556,7 +533,7 @@ class VisualPromptCounting(Tool):
-------
>>> import vision_agent as va
>>> prompt_count = va.tools.VisualPromptCounting()
>>> prompt_count(image="image1.jpg", prompt="100, 100, 200, 250")
>>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
{'count': 23}
"""

Expand All @@ -570,25 +547,25 @@ class VisualPromptCounting(Tool):
],
"examples": [
{
"scenario": "Here is an example of a lid '200, 200, 250, 300', Can you count the lids in the image ? Image name: lids.jpg",
"parameters": {"image": "lids.jpg", "prompt": "200, 200, 250, 300"},
"scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the lids in the image ? Image name: lids.jpg",
"parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
},
{
"scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
"parameters": {"image": "tray.jpg", "prompt": "100, 100, 200, 250"},
"parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
},
{
"scenario": "Can you build me a few shot object counting tool ? Image name: shirts.jpg",
"parameters": {
"image": "shirts.jpg",
"prompt": "100, 100, 200, 250",
"prompt": "0.1, 0.15, 0.2, 0.2",
},
},
{
"scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg",
"parameters": {
"image": "shoes.jpg",
"prompt": "150, 100, 500, 550",
"prompt": "0.1, 0.1, 0.6, 0.65",
},
},
],
Expand All @@ -604,7 +581,11 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
Returns:
A dictionary containing the key 'count' and the count as value. E.g. {count: 12}
"""
image_size = get_image_size(image)
bbox = [float(x) for x in prompt.split(",")]
prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
image_b64 = convert_to_b64(image)

data = {
"image": image_b64,
"prompt": prompt,
Expand Down

0 comments on commit 6057682

Please sign in to comment.