Skip to content

Commit

Permalink
adding object distance tool
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Apr 29, 2024
1 parent cd9932c commit c9d0311
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 5 deletions.
31 changes: 30 additions & 1 deletion tests/tools/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import tempfile
from pathlib import Path

import numpy as np
import pytest
from PIL import Image

from vision_agent.tools import TOOLS, Tool, register_tool
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU, MaskDistance


def test_bbox_iou():
Expand Down Expand Up @@ -69,6 +70,34 @@ def test_box_distance():
assert box_dist(box1, box2) == 0.0


def test_mask_distance():
    """Check MaskDistance on two square masks placed in opposite corners."""
    # Two 10x10 filled squares on otherwise empty 100x100 canvases.
    top_left = np.zeros((100, 100))
    top_left[:10, :10] = 1
    bottom_right = np.zeros((100, 100))
    bottom_right[-10:, -10:] = 1

    # The nearest pixels are (9, 9) and (90, 90): 81 apart along each axis.
    expected = np.sqrt(2) * 81

    with tempfile.TemporaryDirectory() as tmpdir:
        # Persist both masks as 8-bit PNG files for the tool to load.
        saved_paths = []
        for file_name, mask in (("mask1.png", top_left), ("mask2.png", bottom_right)):
            file_path = os.path.join(tmpdir, file_name)
            Image.fromarray((mask * 255).astype(np.uint8)).save(file_path)
            saved_paths.append(file_path)

        # The tool opens the files itself, so call it before the directory is removed.
        tool = MaskDistance()
        distance = tool(*saved_paths)

    print(f"Distance between the masks: {distance}")

    assert np.isclose(distance, expected, atol=1e-2), f"Expected {expected}, got {distance}"


def test_register_tool():
assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_"

Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TOOLS,
BboxArea,
BboxIoU,
ObjectDistance,
BoxDistance,
MaskDistance,
Crop,
Expand Down
48 changes: 44 additions & 4 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,46 @@ def __call__(
}


class ObjectDistance(Tool):
    """Tool that measures the minimum distance between two objects.

    Each object is a dictionary. When both dictionaries contain a "masks"
    entry the work is delegated to ``MaskDistance``; otherwise, when both
    contain a "bboxes" entry, it is delegated to ``BoxDistance``.
    """

    name = "object_distance_"
    description = "'object_distance_' calculates the distance between two objects in an image. It returns the minimum distance between the two objects."
    usage = {
        "required_parameters": [
            {"name": "object1", "type": "Dict[str, Any]"},
            {"name": "object2", "type": "Dict[str, Any]"},
        ],
        "examples": [
            {
                "scenario": "Calculate the distance between these two objects {bboxes: [0.2, 0.21, 0.34, 0.42], masks: 'mask_file1.png'}, {bboxes: [0.3, 0.31, 0.44, 0.52], masks: 'mask_file2.png'}",
                "parameters": {
                    "object1": {
                        "bboxes": [0.2, 0.21, 0.34, 0.42],
                        "scores": 0.54,
                        "masks": "mask_file1.png",
                    },
                    "object2": {
                        "bboxes": [0.3, 0.31, 0.44, 0.52],
                        "scores": 0.66,
                        "masks": "mask_file2.png",
                    },
                },
            }
        ],
    }

    def __call__(self, object1: Dict[str, Any], object2: Dict[str, Any]) -> float:
        """Return the distance between ``object1`` and ``object2``.

        Args:
            object1: First object; must provide "masks" or "bboxes".
            object2: Second object; must provide the same key as ``object1``.

        Returns:
            The value computed by ``MaskDistance`` when both objects have
            "masks", otherwise the value computed by ``BoxDistance``.

        Raises:
            ValueError: If the objects share neither "masks" nor "bboxes".
        """
        # Masks are checked first, so they take precedence when both keys exist.
        if "masks" in object1 and "masks" in object2:
            return MaskDistance()(object1["masks"], object2["masks"])

        if "bboxes" in object1 and "bboxes" in object2:
            return BoxDistance()(object1["bboxes"], object2["bboxes"])

        # A key present on only one object is not enough: both must share it.
        raise ValueError(
            "Both objects must have 'masks' or both objects must have 'bboxes'"
        )


class BoxDistance(Tool):
name = "box_distance_"
description = "'box_distance_' calculates distance between two bounding boxes. It returns the minumum distance between the given bounding boxes"
Expand All @@ -966,7 +1006,7 @@ class BoxDistance(Tool):
],
"examples": [
{
"scenario": "Calculate the distance between the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
"scenario": "Calculate the distance between these two bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
"parameters": {
"bbox1": [0.2, 0.21, 0.34, 0.42],
"bbox2": [0.3, 0.31, 0.44, 0.52],
Expand Down Expand Up @@ -1006,6 +1046,7 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
pil_mask2 = Image.open(str(mask2))
np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
np_mask2 = np.clip(np.array(pil_mask2), 0, 1)

mask1_points = np.transpose(np.nonzero(np_mask1))
mask2_points = np.transpose(np.nonzero(np_mask2))
dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
Expand Down Expand Up @@ -1146,10 +1187,9 @@ def __call__(self, equation: str) -> float:
Crop,
BboxArea,
SegArea,
SegIoU,
MaskDistance,
ObjectDistance,
BboxContains,
BoxDistance,
SegIoU,
OCR,
Calculator,
]
Expand Down

0 comments on commit c9d0311

Please sign in to comment.