diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 00f85072..c5d7c6cc 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image -from vision_agent.tools.tools import BboxIoU, SegIoU +from vision_agent.tools.tools import BboxIoU, SegArea, SegIoU def test_bbox_iou(): @@ -24,3 +24,21 @@ def test_seg_iou(): Image.fromarray(mask1).save(mask1_path) Image.fromarray(mask2).save(mask2_path) assert SegIoU()(mask1_path, mask2_path) == 0.14 + + +def test_seg_area_1(): + mask = np.zeros((10, 10), dtype=np.uint8) + mask[2:4, 2:4] = 255 + with tempfile.TemporaryDirectory() as tmpdir: + mask_path = os.path.join(tmpdir, "mask.png") + Image.fromarray(mask).save(mask_path) + assert SegArea()(mask_path) == 4.0 + + +def test_seg_area_2(): + mask = np.zeros((10, 10), dtype=np.uint8) + mask[2:4, 2:4] = 1 + with tempfile.TemporaryDirectory() as tmpdir: + mask_path = os.path.join(tmpdir, "mask.png") + Image.fromarray(mask).save(mask_path) + assert SegArea()(mask_path) == 4.0 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 151faa36..74d56c0a 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -445,7 +445,8 @@ class SegArea(Tool): def __call__(self, masks: Union[str, Path]) -> float: pil_mask = Image.open(str(masks)) np_mask = np.array(pil_mask) - return cast(float, round(np.sum(np_mask) / 255, 2)) + np_mask = np.clip(np_mask, 0, 1) + return cast(float, round(np.sum(np_mask), 2)) class BboxIoU(Tool):