Skip to content

Commit c9d0311

Browse files
adding object distance tool
1 parent cd9932c commit c9d0311

File tree

3 files changed

+75
-5
lines changed

3 files changed

+75
-5
lines changed

tests/tools/test_tools.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import os
22
import tempfile
3+
from pathlib import Path
34

45
import numpy as np
56
import pytest
67
from PIL import Image
78

89
from vision_agent.tools import TOOLS, Tool, register_tool
9-
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU
10+
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU, MaskDistance
1011

1112

1213
def test_bbox_iou():
@@ -69,6 +70,34 @@ def test_box_distance():
6970
assert box_dist(box1, box2) == 0.0
7071

7172

73+
def test_mask_distance():
74+
# Create two binary masks
75+
mask1 = np.zeros((100, 100))
76+
mask1[:10, :10] = 1 # Top left
77+
mask2 = np.zeros((100, 100))
78+
mask2[-10:, -10:] = 1 # Bottom right
79+
80+
# Save the masks as image files
81+
82+
with tempfile.TemporaryDirectory() as tmpdir:
83+
mask1_path = os.path.join(tmpdir, "mask1.png")
84+
mask2_path = os.path.join(tmpdir, "mask2.png")
85+
Image.fromarray((mask1 * 255).astype(np.uint8)).save(mask1_path)
86+
Image.fromarray((mask2 * 255).astype(np.uint8)).save(mask2_path)
87+
88+
# Calculate the distance between the masks
89+
tool = MaskDistance()
90+
distance = tool(mask1_path, mask2_path)
91+
print(f"Distance between the masks: {distance}")
92+
93+
# Check the result
94+
assert np.isclose(
95+
distance,
96+
np.sqrt(2) * 81,
97+
atol=1e-2,
98+
), f"Expected {np.sqrt(2) * 81}, got {distance}"
99+
100+
72101
def test_register_tool():
73102
assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_"
74103

vision_agent/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
TOOLS,
66
BboxArea,
77
BboxIoU,
8+
ObjectDistance,
89
BoxDistance,
910
MaskDistance,
1011
Crop,

vision_agent/tools/tools.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,46 @@ def __call__(
956956
}
957957

958958

959+
class ObjectDistance(Tool):
960+
name = "object_distance_"
961+
description = "'object_distance_' calculates the distance between two objects in an image. It returns the minimum distance between the two objects."
962+
usage = {
963+
"required_parameters": [
964+
{"name": "object1", "type": "Dict[str, Any]"},
965+
{"name": "object2", "type": "Dict[str, Any]"},
966+
],
967+
"examples": [
968+
{
969+
"scenario": "Calculate the distance between these two objects {bboxes: [0.2, 0.21, 0.34, 0.42], masks: 'mask_file1.png'}, {bboxes: [0.3, 0.31, 0.44, 0.52], masks: 'mask_file2.png'}",
970+
"parameters": {
971+
"object1": {
972+
"bboxes": [0.2, 0.21, 0.34, 0.42],
973+
"scores": 0.54,
974+
"masks": "mask_file1.png",
975+
},
976+
"object2": {
977+
"bboxes": [0.3, 0.31, 0.44, 0.52],
978+
"scores": 0.66,
979+
"masks": "mask_file2.png",
980+
},
981+
},
982+
}
983+
],
984+
}
985+
986+
def __call__(self, object1: Dict[str, Any], object2: Dict[str, Any]) -> float:
987+
if "masks" in object1 and "masks" in object2:
988+
mask1 = object1["masks"]
989+
mask2 = object2["masks"]
990+
return MaskDistance()(mask1, mask2)
991+
elif "bboxes" in object1 and "bboxes" in object2:
992+
bbox1 = object1["bboxes"]
993+
bbox2 = object2["bboxes"]
994+
return BoxDistance()(bbox1, bbox2)
995+
else:
996+
raise ValueError("Either of the objects should have masks or bboxes")
997+
998+
959999
class BoxDistance(Tool):
9601000
name = "box_distance_"
9611001
description = "'box_distance_' calculates distance between two bounding boxes. It returns the minumum distance between the given bounding boxes"
@@ -966,7 +1006,7 @@ class BoxDistance(Tool):
9661006
],
9671007
"examples": [
9681008
{
969-
"scenario": "Calculate the distance between the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
1009+
"scenario": "Calculate the distance between these two bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
9701010
"parameters": {
9711011
"bbox1": [0.2, 0.21, 0.34, 0.42],
9721012
"bbox2": [0.3, 0.31, 0.44, 0.52],
@@ -1006,6 +1046,7 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
10061046
pil_mask2 = Image.open(str(mask2))
10071047
np_mask1 = np.clip(np.array(pil_mask1), 0, 1)
10081048
np_mask2 = np.clip(np.array(pil_mask2), 0, 1)
1049+
10091050
mask1_points = np.transpose(np.nonzero(np_mask1))
10101051
mask2_points = np.transpose(np.nonzero(np_mask2))
10111052
dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean")
@@ -1146,10 +1187,9 @@ def __call__(self, equation: str) -> float:
11461187
Crop,
11471188
BboxArea,
11481189
SegArea,
1149-
SegIoU,
1150-
MaskDistance,
1190+
ObjectDistance,
11511191
BboxContains,
1152-
BoxDistance,
1192+
SegIoU,
11531193
OCR,
11541194
Calculator,
11551195
]

0 commit comments

Comments
 (0)