diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 12c21347..6de8d6c8 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -2,8 +2,10 @@ import tempfile import numpy as np +import pytest from PIL import Image +from vision_agent.tools import TOOLS, Tool, register_tool from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU @@ -65,3 +67,71 @@ def test_box_distance(): box1 = [0, 0, 2, 2] box2 = [1, 1, 3, 3] assert box_dist(box1, box2) == 0.0 + + +def test_register_tool(): + assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_" + + @register_tool + class TestTool(Tool): + name = "test_tool_" + description = "Test Tool" + usage = { + "required_parameters": [{"name": "prompt", "type": "str"}], + "examples": [ + { + "scenario": "Test", + "parameters": {"prompt": "Test Prompt"}, + } + ], + } + + def __call__(self, prompt: str) -> str: + return prompt + + assert TOOLS[len(TOOLS) - 1]["name"] == "test_tool_" + + +def test_register_tool_incorrect(): + with pytest.raises(ValueError): + + @register_tool + class NoAttributes(Tool): + pass + + with pytest.raises(ValueError): + + @register_tool + class NoName(Tool): + description = "Test Tool" + usage = { + "required_parameters": [{"name": "prompt", "type": "str"}], + "examples": [ + { + "scenario": "Test", + "parameters": {"prompt": "Test Prompt"}, + } + ], + } + + with pytest.raises(ValueError): + + @register_tool + class NoDescription(Tool): + name = "test_tool_" + usage = { + "required_parameters": [{"name": "prompt", "type": "str"}], + "examples": [ + { + "scenario": "Test", + "parameters": {"prompt": "Test Prompt"}, + } + ], + } + + with pytest.raises(ValueError): + + @register_tool + class NoUsage(Tool): + name = "test_tool_" + description = "Test Tool" diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 1c1c6e73..bf6bf70d 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -14,4 +14,5 @@ SegArea, SegIoU, Tool, + register_tool, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 6d2a7b47..711fc894 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -2,7 +2,7 @@ import tempfile from abc import ABC from pathlib import Path -from typing import Any, Dict, List, Tuple, Union, cast +from typing import Any, Dict, List, Tuple, Type, Union, cast import numpy as np import requests @@ -765,6 +765,30 @@ def __call__(self, equation: str) -> float: } +def register_tool(tool: Type[Tool]) -> None: + r"""Add a tool to the list of available tools. + + Parameters: + tool: The tool to add. + """ + + if ( + not hasattr(tool, "name") + or not hasattr(tool, "description") + or not hasattr(tool, "usage") + ): + raise ValueError( + "The tool must have 'name', 'description' and 'usage' attributes." + ) + + TOOLS[len(TOOLS)] = { + "name": tool.name, + "description": tool.description, + "usage": tool.usage, + "class": tool, + } + + def _send_inference_request( payload: Dict[str, Any], endpoint_name: str ) -> Dict[str, Any]: