Skip to content

Commit

Permalink
updated easytools
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 20, 2024
1 parent f038363 commit a7f4b58
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions vision_agent/agent/easytool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

from vision_agent import LLM, LMM, OpenAILLM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -141,6 +141,10 @@ def answer_summarize(
return model(prompt)


def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
return tool()(**parameters)


def retrieval(
model: Union[LLM, LMM, Agent],
question: str,
Expand All @@ -167,28 +171,22 @@ def retrieval(
pass

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results = []
call_results: List[Any] = []
if isinstance(result["parameters"], Dict):
parameters = {}
for key in result["parameters"]:
parameters[change_name(key)] = result["parameters"][key]
# TODO: wrap call to handle errors
call_result = tools[tool_id]["class"]()(**parameters)
if call_result is None:
continue
call_results.append(call_result)
call_result = function_call(tools[tool_id]["class"], result["parameters"])
if call_result is None:
return call_results
call_results.append(call_result)
elif isinstance(result["parameters"], List):
for param_list in result["parameters"]:
parameters = {}
for key in param_list:
parameters[change_name(key)] = param_list[key]
call_result = tools[tool_id]["class"]()(**parameters)
for parameters in result["parameters"]:
call_result = function_call(tools[tool_id]["class"], parameters)
if call_result is None:
continue
call_results.append(call_result)
return call_results

call_results = []
__import__("ipdb").set_trace()
if isinstance(tool_results, Set) or isinstance(tool_results, List):
for result in tool_results:
call_results.extend(parse_tool_results(result))
Expand All @@ -210,6 +208,9 @@ class EasyTool(Agent):
>>> resp = agent("If a car is traveling at 64 km/h, how many kilometers does it travel in 29 minutes?")
>>> print(resp)
>>> "It will travel approximately 31.03 kilometers in 29 minutes."
>>> resp = agent("How many cards are in this image?", image="cards.jpg")
>>> print(resp)
>>> "There are 2 cards in this image."
"""

def __init__(
Expand Down Expand Up @@ -238,6 +239,8 @@ def chat(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> str:
question = chat[0]["content"]
if image:
question += f" Image name: {image}"
tasks = task_decompose(
self.task_model,
question,
Expand Down

0 comments on commit a7f4b58

Please sign in to comment.