Skip to content

Commit 03c2480

Browse files
committed
added custom tools
1 parent c505b4e commit 03c2480

File tree

3 files changed

+96
-1
lines changed

3 files changed

+96
-1
lines changed

tests/tools/test_tools.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import tempfile
33

44
import numpy as np
5+
import pytest
56
from PIL import Image
67

8+
from vision_agent.tools import TOOLS, Tool, register_tool
79
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU
810

911

@@ -65,3 +67,71 @@ def test_box_distance():
6567
box1 = [0, 0, 2, 2]
6668
box2 = [1, 1, 3, 3]
6769
assert box_dist(box1, box2) == 0.0
70+
71+
72+
def test_register_tool():
73+
assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_"
74+
75+
@register_tool
76+
class TestTool(Tool):
77+
name = "test_tool_"
78+
description = "Test Tool"
79+
usage = {
80+
"required_parameters": [{"name": "prompt", "type": "str"}],
81+
"examples": [
82+
{
83+
"scenario": "Test",
84+
"parameters": {"prompt": "Test Prompt"},
85+
}
86+
],
87+
}
88+
89+
def __call__(self, prompt: str) -> str:
90+
return prompt
91+
92+
assert TOOLS[len(TOOLS) - 1]["name"] == "test_tool_"
93+
94+
95+
def test_register_tool_incorrect():
96+
with pytest.raises(ValueError):
97+
98+
@register_tool
99+
class NoAttributes(Tool):
100+
pass
101+
102+
with pytest.raises(ValueError):
103+
104+
@register_tool
105+
class NoName(Tool):
106+
description = "Test Tool"
107+
usage = {
108+
"required_parameters": [{"name": "prompt", "type": "str"}],
109+
"examples": [
110+
{
111+
"scenario": "Test",
112+
"parameters": {"prompt": "Test Prompt"},
113+
}
114+
],
115+
}
116+
117+
with pytest.raises(ValueError):
118+
119+
@register_tool
120+
class NoDescription(Tool):
121+
name = "test_tool_"
122+
usage = {
123+
"required_parameters": [{"name": "prompt", "type": "str"}],
124+
"examples": [
125+
{
126+
"scenario": "Test",
127+
"parameters": {"prompt": "Test Prompt"},
128+
}
129+
],
130+
}
131+
132+
with pytest.raises(ValueError):
133+
134+
@register_tool
135+
class NoUsage(Tool):
136+
name = "test_tool_"
137+
description = "Test Tool"

vision_agent/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@
1414
SegArea,
1515
SegIoU,
1616
Tool,
17+
register_tool,
1718
)

vision_agent/tools/tools.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import tempfile
33
from abc import ABC
44
from pathlib import Path
5-
from typing import Any, Dict, List, Tuple, Union, cast
5+
from typing import Any, Dict, List, Tuple, Type, Union, cast
66

77
import numpy as np
88
import requests
@@ -765,6 +765,30 @@ def __call__(self, equation: str) -> float:
765765
}
766766

767767

768+
def register_tool(tool: Type[Tool]) -> None:
769+
r"""Add a tool to the list of available tools.
770+
771+
Parameters:
772+
tool: The tool to add.
773+
"""
774+
775+
if (
776+
not hasattr(tool, "name")
777+
or not hasattr(tool, "description")
778+
or not hasattr(tool, "usage")
779+
):
780+
raise ValueError(
781+
"The tool must have 'name', 'description' and 'usage' attributes."
782+
)
783+
784+
TOOLS[len(TOOLS)] = {
785+
"name": tool.name,
786+
"description": tool.description,
787+
"usage": tool.usage,
788+
"class": tool,
789+
}
790+
791+
768792
def _send_inference_request(
769793
payload: Dict[str, Any], endpoint_name: str
770794
) -> Dict[str, Any]:

0 commit comments

Comments
 (0)