diff --git a/poetry.lock b/poetry.lock index 04733555..c74b4af9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1473,7 +1473,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1715,45 +1714,45 @@ test = ["asv", "matplotlib (>=3.5)", "numpydoc (>=1.5)", "pooch (>=1.6.0)", "pyt [[package]] name = "scipy" -version = "1.12.0" +version = "1.13.0" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.9" files = [ - {file = "scipy-1.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:78e4402e140879387187f7f25d91cc592b3501a2e51dfb320f48dfb73565f10b"}, - {file = "scipy-1.12.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:f5f00ebaf8de24d14b8449981a2842d404152774c1a1d880c901bf454cb8e2a1"}, - {file = 
"scipy-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e53958531a7c695ff66c2e7bb7b79560ffdc562e2051644c5576c39ff8efb563"}, - {file = "scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e32847e08da8d895ce09d108a494d9eb78974cf6de23063f93306a3e419960c"}, - {file = "scipy-1.12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4c1020cad92772bf44b8e4cdabc1df5d87376cb219742549ef69fc9fd86282dd"}, - {file = "scipy-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:75ea2a144096b5e39402e2ff53a36fecfd3b960d786b7efd3c180e29c39e53f2"}, - {file = "scipy-1.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:408c68423f9de16cb9e602528be4ce0d6312b05001f3de61fe9ec8b1263cad08"}, - {file = "scipy-1.12.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5adfad5dbf0163397beb4aca679187d24aec085343755fcdbdeb32b3679f254c"}, - {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3003652496f6e7c387b1cf63f4bb720951cfa18907e998ea551e6de51a04467"}, - {file = "scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b8066bce124ee5531d12a74b617d9ac0ea59245246410e19bca549656d9a40a"}, - {file = "scipy-1.12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8bee4993817e204d761dba10dbab0774ba5a8612e57e81319ea04d84945375ba"}, - {file = "scipy-1.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a24024d45ce9a675c1fb8494e8e5244efea1c7a09c60beb1eeb80373d0fecc70"}, - {file = "scipy-1.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e7e76cc48638228212c747ada851ef355c2bb5e7f939e10952bc504c11f4e372"}, - {file = "scipy-1.12.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f7ce148dffcd64ade37b2df9315541f9adad6efcaa86866ee7dd5db0c8f041c3"}, - {file = "scipy-1.12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c39f92041f490422924dfdb782527a4abddf4707616e07b021de33467f917bc"}, - {file = 
"scipy-1.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a7ebda398f86e56178c2fa94cad15bf457a218a54a35c2a7b4490b9f9cb2676c"}, - {file = "scipy-1.12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:95e5c750d55cf518c398a8240571b0e0782c2d5a703250872f36eaf737751338"}, - {file = "scipy-1.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:e646d8571804a304e1da01040d21577685ce8e2db08ac58e543eaca063453e1c"}, - {file = "scipy-1.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:913d6e7956c3a671de3b05ccb66b11bc293f56bfdef040583a7221d9e22a2e35"}, - {file = "scipy-1.12.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba1b0c7256ad75401c73e4b3cf09d1f176e9bd4248f0d3112170fb2ec4db067"}, - {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:730badef9b827b368f351eacae2e82da414e13cf8bd5051b4bdfd720271a5371"}, - {file = "scipy-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6546dc2c11a9df6926afcbdd8a3edec28566e4e785b915e849348c6dd9f3f490"}, - {file = "scipy-1.12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:196ebad3a4882081f62a5bf4aeb7326aa34b110e533aab23e4374fcccb0890dc"}, - {file = "scipy-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:b360f1b6b2f742781299514e99ff560d1fe9bd1bff2712894b52abe528d1fd1e"}, - {file = "scipy-1.12.0.tar.gz", hash = "sha256:4bf5abab8a36d20193c698b0f1fc282c1d083c94723902c447e5d2f1780936a3"}, + {file = "scipy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba419578ab343a4e0a77c0ef82f088238a93eef141b2b8017e46149776dfad4d"}, + {file = "scipy-1.13.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:22789b56a999265431c417d462e5b7f2b487e831ca7bef5edeb56efe4c93f86e"}, + {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05f1432ba070e90d42d7fd836462c50bf98bd08bed0aa616c359eed8a04e3922"}, + {file = "scipy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:b8434f6f3fa49f631fae84afee424e2483289dfc30a47755b4b4e6b07b2633a4"}, + {file = "scipy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dcbb9ea49b0167de4167c40eeee6e167caeef11effb0670b554d10b1e693a8b9"}, + {file = "scipy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:1d2f7bb14c178f8b13ebae93f67e42b0a6b0fc50eba1cd8021c9b6e08e8fb1cd"}, + {file = "scipy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fbcf8abaf5aa2dc8d6400566c1a727aed338b5fe880cde64907596a89d576fa"}, + {file = "scipy-1.13.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5e4a756355522eb60fcd61f8372ac2549073c8788f6114449b37e9e8104f15a5"}, + {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5acd8e1dbd8dbe38d0004b1497019b2dbbc3d70691e65d69615f8a7292865d7"}, + {file = "scipy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ff7dad5d24a8045d836671e082a490848e8639cabb3dbdacb29f943a678683d"}, + {file = "scipy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4dca18c3ffee287ddd3bc8f1dabaf45f5305c5afc9f8ab9cbfab855e70b2df5c"}, + {file = "scipy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:a2f471de4d01200718b2b8927f7d76b5d9bde18047ea0fa8bd15c5ba3f26a1d6"}, + {file = "scipy-1.13.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0de696f589681c2802f9090fff730c218f7c51ff49bf252b6a97ec4a5d19e8b"}, + {file = "scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:b2a3ff461ec4756b7e8e42e1c681077349a038f0686132d623fa404c0bee2551"}, + {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf9fe63e7a4bf01d3645b13ff2aa6dea023d38993f42aaac81a18b1bda7a82a"}, + {file = "scipy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1e7626dfd91cdea5714f343ce1176b6c4745155d234f1033584154f60ef1ff42"}, + {file = "scipy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:109d391d720fcebf2fbe008621952b08e52907cf4c8c7efc7376822151820820"}, + {file = "scipy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:8930ae3ea371d6b91c203b1032b9600d69c568e537b7988a3073dfe4d4774f21"}, + {file = "scipy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5407708195cb38d70fd2d6bb04b1b9dd5c92297d86e9f9daae1576bd9e06f602"}, + {file = "scipy-1.13.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ac38c4c92951ac0f729c4c48c9e13eb3675d9986cc0c83943784d7390d540c78"}, + {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09c74543c4fbeb67af6ce457f6a6a28e5d3739a87f62412e4a16e46f164f0ae5"}, + {file = "scipy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28e286bf9ac422d6beb559bc61312c348ca9b0f0dae0d7c5afde7f722d6ea13d"}, + {file = "scipy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:33fde20efc380bd23a78a4d26d59fc8704e9b5fd9b08841693eb46716ba13d86"}, + {file = "scipy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:45c08bec71d3546d606989ba6e7daa6f0992918171e2a6f7fbedfa7361c2de1e"}, + {file = "scipy-1.13.0.tar.gz", hash = "sha256:58569af537ea29d3f78e5abd18398459f195546bb3be23d16677fb26616cc11e"}, ] [package.dependencies] -numpy = ">=1.22.4,<1.29.0" +numpy = ">=1.22.4,<2.3" [package.extras] -dev = ["click", "cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] -doc = ["jupytext", "matplotlib (>2)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (==0.9.0)", "sphinx (!=4.1.0)", "sphinx-design (>=0.2.0)"] -test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy", "pycodestyle", "pydevtool", "rich-click", "ruff", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.12.0)", 
"jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0)", "sphinx-design (>=0.4.0)"] +test = ["array-api-strict", "asv", "gmpy2", "hypothesis (>=6.30)", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "setuptools" @@ -2011,4 +2010,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9" -content-hash = "5c57abb9e41b66ee6195875318971f2c2ce858879efca8f5edefb2226284ee6c" +content-hash = "044e8fb239976241edc1e8ed1c91882aa77bfc11264e867129ba3dd13ef07ef2" diff --git a/pyproject.toml b/pyproject.toml index e27a9e2a..6e77b41a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ moviepy = "1.*" opencv-python-headless = "4.*" tabulate = "^0.9.0" pydantic-settings = "^2.2.1" +scipy = "1.13.*" [tool.poetry.group.dev.dependencies] autoflake = "1.*" diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 6de8d6c8..d350ad21 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -6,7 +6,7 @@ from PIL import Image from vision_agent.tools import TOOLS, Tool, register_tool -from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU +from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU, MaskDistance def test_bbox_iou(): @@ -69,6 +69,34 @@ def test_box_distance(): assert box_dist(box1, box2) == 0.0 +def test_mask_distance(): + # Create two binary masks + mask1 = np.zeros((100, 100)) + mask1[:10, :10] = 1 # Top left + mask2 = np.zeros((100, 100)) + mask2[-10:, -10:] = 1 # Bottom right + + # Save the masks as image files + + with tempfile.TemporaryDirectory() as tmpdir: + mask1_path = os.path.join(tmpdir, "mask1.png") + mask2_path = os.path.join(tmpdir, "mask2.png") + Image.fromarray((mask1 * 255).astype(np.uint8)).save(mask1_path) + Image.fromarray((mask2 * 
255).astype(np.uint8)).save(mask2_path) + + # Calculate the distance between the masks + tool = MaskDistance() + distance = tool(mask1_path, mask2_path) + print(f"Distance between the masks: {distance}") + + # Check the result + assert np.isclose( + distance, + np.sqrt(2) * 81, + atol=1e-2, + ), f"Expected {np.sqrt(2) * 81}, got {distance}" + + def test_register_tool(): assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_" diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 10daf7eb..eb6da91a 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -5,7 +5,9 @@ TOOLS, BboxArea, BboxIoU, + ObjectDistance, BoxDistance, + MaskDistance, Crop, DINOv, ExtractFrames, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index a9eca833..aec4f19b 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -9,6 +9,7 @@ import requests from PIL import Image from PIL.Image import Image as ImageType +from scipy.spatial import distance # type: ignore from vision_agent.image_utils import ( b64_to_pil, @@ -544,7 +545,7 @@ class VisualPromptCounting(Tool): ------- >>> import vision_agent as va >>> prompt_count = va.tools.VisualPromptCounting() - >>> prompt_count(image="image1.jpg", prompt="0.1, 0.1, 0.4, 0.42") + >>> prompt_count(image="image1.jpg", prompt={"bbox": [0.1, 0.1, 0.4, 0.42]}) {'count': 23} """ @@ -554,52 +555,60 @@ class VisualPromptCounting(Tool): usage = { "required_parameters": [ {"name": "image", "type": "str"}, - {"name": "prompt", "type": "str"}, + {"name": "prompt", "type": "Dict[str, List[float]]"}, ], "examples": [ { "scenario": "Here is an example of a lid '0.1, 0.1, 0.14, 0.2', Can you count the items in the image ? 
Image name: lids.jpg", - "parameters": {"image": "lids.jpg", "prompt": "0.1, 0.1, 0.14, 0.2"}, + "parameters": { + "image": "lids.jpg", + "prompt": {"bbox": [0.1, 0.1, 0.14, 0.2]}, + }, }, { - "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg", - "parameters": {"image": "tray.jpg", "prompt": "0.1, 0.1, 0.2, 0.25"}, + "scenario": "Can you count the total number of objects in this image ? Image name: tray.jpg, reference_data: {'bbox': [0.1, 0.1, 0.2, 0.25]}", + "parameters": { + "image": "tray.jpg", + "prompt": {"bbox": [0.1, 0.1, 0.2, 0.25]}, + }, }, { - "scenario": "Can you count this item based on an example, reference_data: '0.1, 0.15, 0.2, 0.2' ? Image name: shirts.jpg", + "scenario": "Can you count this item based on an example, reference_data: {'bbox': [100, 115, 200, 200]} ? Image name: shirts.jpg", "parameters": { "image": "shirts.jpg", - "prompt": "0.1, 0.15, 0.2, 0.2", + "prompt": {"bbox": [100, 115, 200, 200]}, }, }, { - "scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg", + "scenario": "Can you build me a counting tool based on an example prompt ? Image name: shoes.jpg, reference_data: {'bbox': [0.1, 0.1, 0.6, 0.65]}", "parameters": { "image": "shoes.jpg", - "prompt": "0.1, 0.1, 0.6, 0.65", + "prompt": {"bbox": [0.1, 0.1, 0.6, 0.65]}, }, }, ], } - # TODO: Add support for input multiple images, which aligns with the output type. - def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict: + def __call__( + self, image: Union[str, ImageType], prompt: Dict[str, List[float]] + ) -> Dict: """Invoke the few shot counting model. Parameters: image: the input image. + prompt: the visual prompt which is a bounding box describing the object. Returns: A dictionary containing the key 'count' and the count as value. E.g. 
{count: 12} """ image_size = get_image_size(image) - bbox = [float(x) for x in prompt.split(",")] - prompt = ", ".join(map(str, denormalize_bbox(bbox, image_size))) + bbox = prompt["bbox"] + bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size))) image_b64 = convert_to_b64(image) data = { "image": image_b64, - "prompt": prompt, + "prompt": bbox_str, "tool": "few_shot_counting", } resp_data = _send_inference_request(data, "tools") @@ -878,7 +887,7 @@ class SegIoU(Tool): ], "examples": [ { - "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg", + "scenario": "Calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg", "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"}, } ], @@ -947,6 +956,46 @@ def __call__( } +class ObjectDistance(Tool): + name = "object_distance_" + description = "'object_distance_' calculates the distance between two objects in an image. It returns the minimum distance between the two objects." 
+ usage = { + "required_parameters": [ + {"name": "object1", "type": "Dict[str, Any]"}, + {"name": "object2", "type": "Dict[str, Any]"}, + ], + "examples": [ + { + "scenario": "Calculate the distance between these two objects {bboxes: [0.2, 0.21, 0.34, 0.42], masks: 'mask_file1.png'}, {bboxes: [0.3, 0.31, 0.44, 0.52], masks: 'mask_file2.png'}", + "parameters": { + "object1": { + "bboxes": [0.2, 0.21, 0.34, 0.42], + "scores": 0.54, + "masks": "mask_file1.png", + }, + "object2": { + "bboxes": [0.3, 0.31, 0.44, 0.52], + "scores": 0.66, + "masks": "mask_file2.png", + }, + }, + } + ], + } + + def __call__(self, object1: Dict[str, Any], object2: Dict[str, Any]) -> float: + if "masks" in object1 and "masks" in object2: + mask1 = object1["masks"] + mask2 = object2["masks"] + return MaskDistance()(mask1, mask2) + elif "bboxes" in object1 and "bboxes" in object2: + bbox1 = object1["bboxes"] + bbox2 = object2["bboxes"] + return BoxDistance()(bbox1, bbox2) + else: + raise ValueError("Either of the objects should have masks or bboxes") + + class BoxDistance(Tool): name = "box_distance_" description = "'box_distance_' calculates distance between two bounding boxes. It returns the minumum distance between the given bounding boxes" @@ -957,7 +1006,7 @@ class BoxDistance(Tool): ], "examples": [ { - "scenario": "Calculate the distance between the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]", + "scenario": "Calculate the distance between these two bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]", "parameters": { "bbox1": [0.2, 0.21, 0.34, 0.42], "bbox2": [0.3, 0.31, 0.44, 0.52], @@ -976,6 +1025,34 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2)) +class MaskDistance(Tool): + name = "mask_distance_" + description = "'mask_distance_' calculates distance between two masks. It is helpful in checking proximity of two objects. 
It returns the minimum distance between the given masks" + usage = { + "required_parameters": [ + {"name": "mask1", "type": "str"}, + {"name": "mask2", "type": "str"}, + ], + "examples": [ + { + "scenario": "Calculate the distance between the segmentation masks for mask_file1.jpg and mask_file2.jpg", + "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"}, + } + ], + } + + def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float: + pil_mask1 = Image.open(str(mask1)) + pil_mask2 = Image.open(str(mask2)) + np_mask1 = np.clip(np.array(pil_mask1), 0, 1) + np_mask2 = np.clip(np.array(pil_mask2), 0, 1) + + mask1_points = np.transpose(np.nonzero(np_mask1)) + mask2_points = np.transpose(np.nonzero(np_mask2)) + dist_matrix = distance.cdist(mask1_points, mask2_points, "euclidean") + return cast(float, np.round(np.min(dist_matrix), 2)) + + class ExtractFrames(Tool): r"""Extract frames from a video.""" @@ -1110,10 +1187,9 @@ def __call__(self, equation: str) -> float: Crop, BboxArea, SegArea, - BboxIoU, - SegIoU, + ObjectDistance, BboxContains, - BoxDistance, + SegIoU, OCR, Calculator, ]