Skip to content

Commit

Permalink
added custom tools
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 22, 2024
1 parent c505b4e commit 03c2480
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 1 deletion.
70 changes: 70 additions & 0 deletions tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import tempfile

import numpy as np
import pytest
from PIL import Image

from vision_agent.tools import TOOLS, Tool, register_tool
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU


Expand Down Expand Up @@ -65,3 +67,71 @@ def test_box_distance():
box1 = [0, 0, 2, 2]
box2 = [1, 1, 3, 3]
assert box_dist(box1, box2) == 0.0


def test_register_tool():
assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_"

@register_tool
class TestTool(Tool):
name = "test_tool_"
description = "Test Tool"
usage = {
"required_parameters": [{"name": "prompt", "type": "str"}],
"examples": [
{
"scenario": "Test",
"parameters": {"prompt": "Test Prompt"},
}
],
}

def __call__(self, prompt: str) -> str:
return prompt

assert TOOLS[len(TOOLS) - 1]["name"] == "test_tool_"


def test_register_tool_incorrect():
with pytest.raises(ValueError):

@register_tool
class NoAttributes(Tool):
pass

with pytest.raises(ValueError):

@register_tool
class NoName(Tool):
description = "Test Tool"
usage = {
"required_parameters": [{"name": "prompt", "type": "str"}],
"examples": [
{
"scenario": "Test",
"parameters": {"prompt": "Test Prompt"},
}
],
}

with pytest.raises(ValueError):

@register_tool
class NoDescription(Tool):
name = "test_tool_"
usage = {
"required_parameters": [{"name": "prompt", "type": "str"}],
"examples": [
{
"scenario": "Test",
"parameters": {"prompt": "Test Prompt"},
}
],
}

with pytest.raises(ValueError):

@register_tool
class NoUsage(Tool):
name = "test_tool_"
description = "Test Tool"
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@
SegArea,
SegIoU,
Tool,
register_tool,
)
26 changes: 25 additions & 1 deletion vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tempfile
from abc import ABC
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast
from typing import Any, Dict, List, Tuple, Type, Union, cast

import numpy as np
import requests
Expand Down Expand Up @@ -765,6 +765,30 @@ def __call__(self, equation: str) -> float:
}


def register_tool(tool: Type[Tool]) -> None:
r"""Add a tool to the list of available tools.
Parameters:
tool: The tool to add.
"""

if (
not hasattr(tool, "name")
or not hasattr(tool, "description")
or not hasattr(tool, "usage")
):
raise ValueError(
"The tool must have 'name', 'description' and 'usage' attributes."
)

TOOLS[len(TOOLS)] = {
"name": tool.name,
"description": tool.description,
"usage": tool.usage,
"class": tool,
}


def _send_inference_request(
payload: Dict[str, Any], endpoint_name: str
) -> Dict[str, Any]:
Expand Down

0 comments on commit 03c2480

Please sign in to comment.