From 6cd4676bbe75b5f78722f1ef0783134216474d16 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Wed, 20 Mar 2024 10:53:36 -0700
Subject: [PATCH] minor fixes

---
 vision_agent/agent/easytool.py | 85 +++++++++++++++++++++++-----------
 1 file changed, 58 insertions(+), 27 deletions(-)

diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py
index b0fc5f2c..929689ad 100644
--- a/vision_agent/agent/easytool.py
+++ b/vision_agent/agent/easytool.py
@@ -1,6 +1,8 @@
 import json
+import logging
+import sys
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from vision_agent import LLM, LMM, OpenAILLM
 from vision_agent.tools import TOOLS
@@ -15,6 +17,9 @@
     TASK_TOPOLOGY,
 )
 
+logging.basicConfig(stream=sys.stdout)
+_LOGGER = logging.getLogger(__name__)
+
 
 def parse_json(s: str) -> Any:
     s = (
@@ -45,7 +50,7 @@ def format_tools(tools: Dict[int, Any]) -> str:
 
 def task_decompose(
     model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
-) -> Dict:
+) -> Optional[Dict]:
     prompt = TASK_DECOMPOSE.format(question=question, tools=format_tools(tools))
     tries = 0
     str_result = ""
@@ -56,7 +61,8 @@ def task_decompose(
             return result["Tasks"]  # type: ignore
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed task_decompose on: {str_result}")
+                _LOGGER.error(f"Failed task_decompose on: {str_result}")
+                return None
             tries += 1
             continue
 
@@ -81,14 +87,15 @@ def task_topology(
             return result["Tasks"]  # type: ignore
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed task_topology on: {str_result}")
+                _LOGGER.error(f"Failed task_topology on: {str_result}")
+                return task_list
             tries += 1
             continue
 
 
 def choose_tool(
     model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
-) -> int:
+) -> Optional[int]:
     prompt = CHOOSE_TOOL.format(question=question, tools=format_tools(tools))
     tries = 0
     str_result = ""
@@ -99,14 +106,15 @@ def choose_tool(
             return result["ID"]  # type: ignore
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed choose_tool on: {str_result}")
+                _LOGGER.error(f"Failed choose_tool on: {str_result}")
+                return None
             tries += 1
             continue
 
 
 def choose_parameter(
     model: Union[LLM, LMM, Agent], question: str, tool_usage: Dict, previous_log: str
-) -> Any:
+) -> Optional[Any]:
     # TODO: should format tool_usage
     prompt = CHOOSE_PARAMETER.format(
         question=question, tool_usage=tool_usage, previous_log=previous_log
@@ -120,7 +128,8 @@ def choose_parameter(
             return result["Parameters"]
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed choose_parameter on: {str_result}")
+                _LOGGER.error(f"Failed choose_parameter on: {str_result}")
+                return None
             tries += 1
             continue
 
@@ -155,20 +164,21 @@ def retrieval(
     tool_id = choose_tool(
         model, question, {k: v["description"] for k, v in tools.items()}
     )
-    if tool_id is None:  # TODO
-        pass
+    if tool_id is None:
+        return [{}], ""
+    _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
 
     tool_instructions = tools[tool_id]
     tool_usage = tool_instructions["usage"]
     tool_name = tool_instructions["name"]
 
     parameters = choose_parameter(model, question, tool_usage, previous_log)
-    if parameters is None:  # TODO
-        pass
-    tool_results = [{"tool_name": tool_name, "parameters": parameters}]
-
-    if len(tool_results) == 0:  # TODO
-        pass
+    _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
+    if parameters is None:
+        return [{}], ""
+    tool_results = [
+        {"task": question, "tool_name": tool_name, "parameters": parameters}
+    ]
 
     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
         call_results: List[Any] = []
@@ -186,14 +196,12 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
         return call_results
 
     call_results = []
-    __import__("ipdb").set_trace()
-    if isinstance(tool_results, Set) or isinstance(tool_results, List):
-        for result in tool_results:
-            call_results.extend(parse_tool_results(result))
-    elif isinstance(tool_results, Dict):
-        call_results.extend(parse_tool_results(tool_results))
+    for i, result in enumerate(tool_results):
+        call_results.extend(parse_tool_results(result))
+        tool_results[i]["call_results"] = call_results
 
     call_results_str = "\n\n".join([str(e) for e in call_results])
+    _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str
 
 
@@ -217,6 +225,7 @@ def __init__(
         self,
         task_model: Optional[Union[LLM, LMM]] = None,
         answer_model: Optional[Union[LLM, LMM]] = None,
+        verbose: bool = False,
     ):
         self.task_model = (
             OpenAILLM(json_mode=True) if task_model is None else task_model
@@ -225,6 +234,8 @@ def __init__(
         self.answer_model = OpenAILLM() if answer_model is None else answer_model
         self.retrieval_num = 3
         self.tools = TOOLS
+        if verbose:
+            _LOGGER.setLevel(logging.INFO)
 
     def __call__(
         self,
@@ -235,9 +246,9 @@ def __call__(
             input = [{"role": "user", "content": input}]
         return self.chat(input, image=image)
 
-    def chat(
+    def chat_with_workflow(
         self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
-    ) -> str:
+    ) -> Tuple[str, List[Dict]]:
         question = chat[0]["content"]
         if image:
             question += f" Image name: {image}"
@@ -246,17 +257,25 @@ def chat(
             question,
             {k: v["description"] for k, v in self.tools.items()},
         )
-        task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
-        task_list = task_topology(self.task_model, question, task_list)
+        _LOGGER.info(f"Tasks: {tasks}")
+        if tasks is not None:
+            task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
+            task_list = task_topology(self.task_model, question, task_list)
+        else:
+            task_list = []
+
+        _LOGGER.info(f"Task Dependency: {task_list}")
         task_depend = {"Original Quesiton": question}
         previous_log = ""
         answers = []
         for task in task_list:
             task_depend[task["id"]] = {"task": task["task"], "answer": ""}  # type: ignore
         # TODO topological sort task_list
+        all_tool_results = []
         for task in task_list:
             task_str = task["task"]
             previous_log = str(task_depend)
+            _LOGGER.info(f"\tSubtask: {task_str}")
             tool_results, call_results = retrieval(
                 self.task_model,
                 task_str,
@@ -266,6 +285,18 @@ def chat(
             answer = answer_generate(
                 self.answer_model, task_str, call_results, previous_log
             )
+
+            for tool_result in tool_results:
+                tool_result["answer"] = answer
+            all_tool_results.extend(tool_results)
+
+            _LOGGER.info(f"\tAnswer: {answer}")
             answers.append({"task": task_str, "answer": answer})
             task_depend[task["id"]]["answer"] = answer  # type: ignore
-        return answer_summarize(self.answer_model, question, answers)
+        return answer_summarize(self.answer_model, question, answers), all_tool_results
+
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
+        answer, _ = self.chat_with_workflow(chat, image=image)
+        return answer