diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py
index b21114ec..b0fc5f2c 100644
--- a/vision_agent/agent/easytool.py
+++ b/vision_agent/agent/easytool.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 from vision_agent import LLM, LMM, OpenAILLM
 from vision_agent.tools import TOOLS
@@ -141,6 +141,11 @@ def answer_summarize(
     return model(prompt)
 
 
+def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
+    """Build the tool via ``tool()`` and call it with ``parameters`` as kwargs."""
+    return tool()(**parameters)
+
+
 def retrieval(
     model: Union[LLM, LMM, Agent],
     question: str,
@@ -167,28 +172,21 @@ def retrieval(
         pass
 
     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
-        call_results = []
+        call_results: List[Any] = []
         if isinstance(result["parameters"], Dict):
-            parameters = {}
-            for key in result["parameters"]:
-                parameters[change_name(key)] = result["parameters"][key]
-                # TODO: wrap call to handle errors
-                call_result = tools[tool_id]["class"]()(**parameters)
-                if call_result is None:
-                    continue
-                call_results.append(call_result)
+            call_result = function_call(tools[tool_id]["class"], result["parameters"])
+            if call_result is None:
+                return call_results
+            call_results.append(call_result)
         elif isinstance(result["parameters"], List):
-            for param_list in result["parameters"]:
-                parameters = {}
-                for key in param_list:
-                    parameters[change_name(key)] = param_list[key]
-                call_result = tools[tool_id]["class"]()(**parameters)
+            for parameters in result["parameters"]:
+                call_result = function_call(tools[tool_id]["class"], parameters)
                 if call_result is None:
                     continue
                 call_results.append(call_result)
         return call_results
 
     call_results = []
     if isinstance(tool_results, Set) or isinstance(tool_results, List):
         for result in tool_results:
             call_results.extend(parse_tool_results(result))
@@ -210,6 +208,9 @@ class EasyTool(Agent):
>>> resp = agent("If a car is traveling at 64 km/h, how many kilometers does it travel in 29 minutes?") >>> print(resp) >>> "It will travel approximately 31.03 kilometers in 29 minutes." + >>> resp = agent("How many cards are in this image?", image="cards.jpg") + >>> print(resp) + >>> "There are 2 cards in this image." """ def __init__( @@ -238,6 +239,8 @@ def chat( self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None ) -> str: question = chat[0]["content"] + if image: + question += f" Image name: {image}" tasks = task_decompose( self.task_model, question,