Skip to content

Commit

Permalink
SegArea to work with all >0 values (#43)
Browse files Browse the repository at this point in the history
seg area to work with all >0 values
  • Loading branch information
dillonalaird authored Apr 9, 2024
1 parent 49af502 commit c5de3b8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
20 changes: 19 additions & 1 deletion tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL import Image

from vision_agent.tools.tools import BboxIoU, SegIoU
from vision_agent.tools.tools import BboxIoU, SegArea, SegIoU


def test_bbox_iou():
Expand All @@ -24,3 +24,21 @@ def test_seg_iou():
Image.fromarray(mask1).save(mask1_path)
Image.fromarray(mask2).save(mask2_path)
assert SegIoU()(mask1_path, mask2_path) == 0.14


def test_seg_area_1():
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:4, 2:4] = 255
with tempfile.TemporaryDirectory() as tmpdir:
mask_path = os.path.join(tmpdir, "mask.png")
Image.fromarray(mask).save(mask_path)
assert SegArea()(mask_path) == 4.0


def test_seg_area_2():
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:4, 2:4] = 1
with tempfile.TemporaryDirectory() as tmpdir:
mask_path = os.path.join(tmpdir, "mask.png")
Image.fromarray(mask).save(mask_path)
assert SegArea()(mask_path) == 4.0
3 changes: 2 additions & 1 deletion vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ class SegArea(Tool):
def __call__(self, masks: Union[str, Path]) -> float:
pil_mask = Image.open(str(masks))
np_mask = np.array(pil_mask)
return cast(float, round(np.sum(np_mask) / 255, 2))
np_mask = np.clip(np_mask, 0, 1)
return cast(float, round(np.sum(np_mask), 2))


class BboxIoU(Tool):
Expand Down

0 comments on commit c5de3b8

Please sign in to comment.