Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added seg distance and fixed parameter for visual prompt counting #70

Merged
merged 5 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 32 additions & 33 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ moviepy = "1.*"
opencv-python-headless = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"
scipy = "1.13.*"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
BboxArea,
BboxIoU,
BoxDistance,
MaskDistance,
Crop,
DINOv,
ExtractFrames,
Expand Down
68 changes: 52 additions & 16 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import requests
from PIL import Image
from PIL.Image import Image as ImageType
from scipy.spatial import distance # type: ignore

from vision_agent.image_utils import (
b64_to_pil,
Expand Down Expand Up @@ -544,7 +545,7 @@ class VisualPromptCounting(Tool):
-------
>>> import vision_agent as va
>>> prompt_count = va.tools.VisualPromptCounting()
>>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42")
>>> prompt_count(image="image1.jpg", prompt={"bbox": [0.1, 0.1, 0.4, 0.42]})
{'count': 23}
"""

Expand All @@ -554,52 +555,60 @@ class VisualPromptCounting(Tool):
usage = {
"required_parameters": [
{"name": "image", "type": "str"},
{"name": "prompt", "type": "str"},
{"name": "prompt", "type": "Dict[str, List[float]"},
],
"examples": [
{
"scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the items in the image ? Image name: lids.jpg",
"parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"},
"parameters": {
"image": "lids.jpg",
"prompt": {"bbox": [0.1, 0.1, 0.14, 0.2]},
},
},
{
"scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg",
"parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"},
"scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg, reference_data: {'bbox': [0.1, 0.1, 0.2, 0.25]}",
"parameters": {
"image": "tray.jpg",
"prompt": {"bbox": [0.1, 0.1, 0.2, 0.25]},
},
},
{
"scenario": "Can you count this item based on an example, reference_data: '0.1, 0.15, 0.2, 0.2' ? Image name: shirts.jpg",
"scenario": "Can you count this item based on an example, reference_data: {'bbox': [100, 115, 200, 200]} ? Image name: shirts.jpg",
"parameters": {
"image": "shirts.jpg",
"prompt": "0.1, 0.15, 0.2, 0.2",
"prompt": {"bbox": [100, 115, 200, 200]},
},
},
{
"scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg",
"scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg, reference_data: {'bbox': [0.1, 0.1, 0.6, 0.65]}",
"parameters": {
"image": "shoes.jpg",
"prompt": "0.1, 0.1, 0.6, 0.65",
"prompt": {"bbox": [0.1, 0.1, 0.6, 0.65]},
},
},
],
}

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
def __call__(
self, image: Union[str, ImageType], prompt: Dict[str, List[float]]
) -> Dict:
"""Invoke the few shot counting model.

Parameters:
image: the input image.
prompt: the visual prompt which is a bounding box describing the object.

Returns:
A dictionary containing the key 'count' and the count as value. E.g. {count: 12}
"""
image_size = get_image_size(image)
bbox = [float(x) for x in prompt.split(",")]
prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
bbox = prompt["bbox"]
bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
image_b64 = convert_to_b64(image)

data = {
"image": image_b64,
"prompt": prompt,
"prompt": bbox_str,
"tool": "few_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
Expand Down Expand Up @@ -878,7 +887,7 @@ class SegIoU(Tool):
],
"examples": [
{
"scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
"scenario": "Calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
"parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
}
],
Expand Down Expand Up @@ -976,6 +985,33 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2))


class MaskDistance(Tool):
name = "mask_distance_"
description = "'mask_distance_' calculates distance between two masks. It is helpful in checking proximity of two objects. It returns the minumum distance between the given masks"
usage = {
"required_parameters": [
{"name": "mask1", "type": "str"},
{"name": "mask2", "type": "str"},
],
"examples": [
{
"scenario": "Calculate the distance between the segmentation masks for mask_file1.jpg and mask_file2.jpg",
"parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
}
],
}

def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
pil_mask1 = Image.open(str(mask1))
pil_mask2 = Image.open(str(mask2))
np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
mask1_points = np.transpose(np.nonzero(np_mask1))
mask2_points = np.transpose(np.nonzero(np_mask2))
dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
return cast(float, np.round(np.min(dist_matrix), 2))


class ExtractFrames(Tool):
r"""Extract frames from a video."""

Expand Down Expand Up @@ -1110,8 +1146,8 @@ def __call__(self, equation: str) -> float:
Crop,
BboxArea,
SegArea,
BboxIoU,
SegIoU,
MaskDistance,
BboxContains,
BoxDistance,
OCR,
Expand Down
Loading