Skip to content

Commit

Permalink
added box distance
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 17, 2024
1 parent 7ed32af commit be6d1ce
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
25 changes: 24 additions & 1 deletion tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL import Image

from vision_agent.tools.tools import BboxIoU, SegArea, SegIoU
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU


def test_bbox_iou():
Expand Down Expand Up @@ -42,3 +42,26 @@ def test_seg_area_2():
mask_path = os.path.join(tmpdir, "mask.png")
Image.fromarray(mask).save(mask_path)
assert SegArea()(mask_path) == 4.0


def test_box_distance():
box_dist = BoxDistance()
# horizontal dist
box1 = [0, 0, 2, 2]
box2 = [4, 1, 6, 3]
assert box_dist(box1, box2) == 2.0

# vertical dist
box1 = [0, 0, 2, 2]
box2 = [1, 4, 3, 6]
assert box_dist(box1, box2) == 2.0

# vertical and horizontal
box1 = [0, 0, 2, 2]
box2 = [3, 3, 5, 5]
assert box_dist(box1, box2) == 1.41

# overlap
box1 = [0, 0, 2, 2]
box2 = [1, 1, 3, 3]
assert box_dist(box1, box2) == 0.0
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
TOOLS,
BboxArea,
BboxIoU,
BoxDistance,
Crop,
ExtractFrames,
GroundingDINO,
Expand Down
15 changes: 12 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,9 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:

class BoxDistance(Tool):
name = "box_distance_"
description = "'box_distance_' returns the distance between two bounding boxes."
description = (
"'box_distance_' returns the minimum distance between two bounding boxes."
)
usage = {
"required_parameters": [
{"name": "bbox1", "type": "List[int]"},
Expand All @@ -564,8 +566,14 @@ class BoxDistance(Tool):
],
}

def __call__(self, box1: List[int], box2: List[int]) -> float:
raise NotImplementedError("Not implemented yet.")
def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
x11, y11, x12, y12 = bbox1
x21, y21, x22, y22 = bbox2

horizontal_dist = np.max([0, x21 - x12, x11 - x22])
vertical_dist = np.max([0, y21 - y12, y11 - y22])

return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2))


class ExtractFrames(Tool):
Expand Down Expand Up @@ -650,6 +658,7 @@ def __call__(self, equation: str) -> float:
SegArea,
BboxIoU,
SegIoU,
BoxDistance,
Calculator,
]
)
Expand Down

0 comments on commit be6d1ce

Please sign in to comment.