Skip to content

Commit

Permalink
added seg distance and fixed parameter for visual prompt counting
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Apr 27, 2024
1 parent 5ab1f2b commit 6b3fc71
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ moviepy = "1.*"
opencv-python-headless = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"
scipy = "1.13.*"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
BboxArea,
BboxIoU,
BoxDistance,
MaskDistance,
Crop,
DINOv,
ExtractFrames,
Expand Down
64 changes: 50 additions & 14 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
from PIL import Image
from PIL.Image import Image as ImageType
from scipy.spatial import distance

from vision_agent.image_utils import (
b64_to_pil,
Expand Down Expand Up @@ -544,7 +545,7 @@ class VisualPromptCounting(Tool):
-------
>>> import vision_agent as va
>>> prompt_count = va.tools.VisualPromptCounting()
>>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
>>> prompt_count(image="image1.jpg", prompt={"bbox": [0.1, 0.1, 0.4, 0.42]})
{'count': 23}
"""

Expand All @@ -554,46 +555,54 @@ class VisualPromptCounting(Tool):
usage = {
"required_parameters": [
{"name": "image", "type": "str"},
{"name": "prompt", "type": "str"},
{"name": "prompt", "type": "Dict[str, List[float]"},
],
"examples": [
{
"scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the items in the image ? Image name: lids.jpg",
"parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
"parameters": {
"image": "lids.jpg",
"prompt": {"bbox": [0.1, 0.1, 0.14, 0.2]},
},
},
{
"scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
"parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
"scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg, reference_data: {'bbox': [0.1, 0.1, 0.2, 0.25]}",
"parameters": {
"image": "tray.jpg",
"prompt": {"bbox": [0.1, 0.1, 0.2, 0.25]},
},
},
{
"scenario": "Can you count this item based on an example, reference_data: '0.1, 0.15, 0.2, 0.2' ? Image name: shirts.jpg",
"scenario": "Can you count this item based on an example, reference_data: {'bbox': [100, 115, 200, 200]} ? Image name: shirts.jpg",
"parameters": {
"image": "shirts.jpg",
"prompt": "0.1, 0.15, 0.2, 0.2",
"prompt": {"bbox": [100, 115, 200, 200]},
},
},
{
"scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg",
"scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg, reference_data: {'bbox': [0.1, 0.1, 0.6, 0.65]}",
"parameters": {
"image": "shoes.jpg",
"prompt": "0.1, 0.1, 0.6, 0.65",
"prompt": {"bbox": [0.1, 0.1, 0.6, 0.65]},
},
},
],
}

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
def __call__(
self, image: Union[str, ImageType], prompt: Dict[str, List[float]]
) -> Dict:
"""Invoke the few shot counting model.
Parameters:
image: the input image.
prompt: the visual prompt which is a bounding box describing the object.
Returns:
A dictionary containing the key 'count' and the count as value. E.g. {count: 12}
"""
image_size = get_image_size(image)
bbox = [float(x) for x in prompt.split(",")]
bbox = prompt["bbox"]
prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
image_b64 = convert_to_b64(image)

Expand Down Expand Up @@ -878,7 +887,7 @@ class SegIoU(Tool):
],
"examples": [
{
"scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
"scenario": "Calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
"parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
}
],
Expand Down Expand Up @@ -976,6 +985,33 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2))


class MaskDistance(Tool):
    r"""Compute the minimum Euclidean distance between two mask images.

    Each mask is read from disk, its values are clipped to the range [0, 1],
    and every non-zero element is treated as a foreground point. The result
    is the smallest pairwise Euclidean distance between the foreground points
    of the two masks, measured in array-index (pixel) units.
    """

    name = "mask_distance_"
    description = "'mask_distance_' calculates distance between two masks. It is helpful in checking proximity of two objects. It returns the minumum distance between the given masks"
    usage = {
        "required_parameters": [
            {"name": "mask1", "type": "str"},
            {"name": "mask2", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Calculate the distance between the segmentation masks for mask_file1.jpg and mask_file2.jpg",
                "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
            }
        ],
    }

    @staticmethod
    def _foreground_points(mask_path: Union[str, Path]) -> np.ndarray:
        """Return the indices of all non-zero elements of the mask at *mask_path*.

        The returned array has one row per foreground element and one column
        per array dimension of the loaded image.
        """
        # The context manager closes the underlying file handle; np.array()
        # copies the pixel data out before the image is closed.
        with Image.open(str(mask_path)) as pil_mask:
            # Clipping to [0, 1] drops negative values (possible in signed or
            # float image modes) so they are not counted as foreground.
            np_mask = np.clip(np.array(pil_mask), 0, 1)
        return np.transpose(np.nonzero(np_mask))

    def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
        """Return the minimum distance between the foreground of two masks.

        Parameters:
            mask1: path to the first mask image.
            mask2: path to the second mask image.

        Returns:
            The smallest Euclidean distance between any foreground point of
            ``mask1`` and any foreground point of ``mask2``, rounded to two
            decimals. Overlapping masks yield ``0.0``.

        Raises:
            ValueError: if either mask contains no foreground points, since
                no distance is defined in that case.
        """
        mask1_points = self._foreground_points(mask1)
        mask2_points = self._foreground_points(mask2)

        # Without this guard np.min() fails on a zero-size distance matrix
        # with an error message that does not point at the actual cause.
        if mask1_points.size == 0 or mask2_points.size == 0:
            raise ValueError(
                "Cannot compute mask distance: at least one mask has no "
                "non-zero (foreground) pixels."
            )

        # NOTE(review): cdist materialises a len(mask1_points) x
        # len(mask2_points) matrix, which can be large for big masks.
        dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
        # float() converts the NumPy scalar to a builtin float so the return
        # value really matches the annotation.
        return float(np.round(np.min(dist_matrix), 2))


class ExtractFrames(Tool):
r"""Extract frames from a video."""

Expand Down Expand Up @@ -1110,8 +1146,8 @@ def __call__(self, equation: str) -> float:
Crop,
BboxArea,
SegArea,
BboxIoU,
SegIoU,
MaskDistance,
BboxContains,
BoxDistance,
OCR,
Expand Down

0 comments on commit 6b3fc71

Please sign in to comment.