Skip to content

Commit

Permalink
added seg distance and fixed parameter for visual prompt counting (#70)
Browse files Browse the repository at this point in the history
* added seg distance and fixed parameter for visual prompt counting

* updating dependency

* fix linting

* adding object distance tool

* fix linting
  • Loading branch information
shankar-vision-eng authored Apr 29, 2024
1 parent 5ab1f2b commit 305b343
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 53 deletions.
65 changes: 32 additions & 33 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ moviepy = "1.*"
opencv-python-headless = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"
scipy = "1.13.*"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down
30 changes: 29 additions & 1 deletion tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from PIL import Image

from vision_agent.tools import TOOLS, Tool, register_tool
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU, MaskDistance


def test_bbox_iou():
Expand Down Expand Up @@ -69,6 +69,34 @@ def test_box_distance():
assert box_dist(box1, box2) == 0.0


def test_mask_distance():
# Create two binary masks
mask1 = np.zeros((100, 100))
mask1[:10, :10] = 1 # Top left
mask2 = np.zeros((100, 100))
mask2[-10:, -10:] = 1 # Bottom right

# Save the masks as image files

with tempfile.TemporaryDirectory() as tmpdir:
mask1_path = os.path.join(tmpdir, "mask1.png")
mask2_path = os.path.join(tmpdir, "mask2.png")
Image.fromarray((mask1 * 255).astype(np.uint8)).save(mask1_path)
Image.fromarray((mask2 * 255).astype(np.uint8)).save(mask2_path)

# Calculate the distance between the masks
tool = MaskDistance()
distance = tool(mask1_path, mask2_path)
print(f"Distance between the masks: {distance}")

# Check the result
assert np.isclose(
distance,
np.sqrt(2) * 81,
atol=1e-2,
), f"Expected {np.sqrt(2) * 81}, got {distance}"


def test_register_tool():
assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_"

Expand Down
2 changes: 2 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
TOOLS,
BboxArea,
BboxIoU,
ObjectDistance,
BoxDistance,
MaskDistance,
Crop,
DINOv,
ExtractFrames,
Expand Down
Loading

0 comments on commit 305b343

Please sign in to comment.