Skip to content

Commit

Permalink
Fixes for EasyTool (#21)
Browse files Browse the repository at this point in the history
minor fixes
  • Loading branch information
dillonalaird authored Mar 20, 2024
1 parent 38a2e27 commit dcb5d72
Showing 1 changed file with 58 additions and 27 deletions.
85 changes: 58 additions & 27 deletions vision_agent/agent/easytool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import logging
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from vision_agent import LLM, LMM, OpenAILLM
from vision_agent.tools import TOOLS
Expand All @@ -15,6 +17,9 @@
TASK_TOPOLOGY,
)

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)


def parse_json(s: str) -> Any:
s = (
Expand Down Expand Up @@ -45,7 +50,7 @@ def format_tools(tools: Dict[int, Any]) -> str:

def task_decompose(
model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
) -> Dict:
) -> Optional[Dict]:
prompt = TASK_DECOMPOSE.format(question=question, tools=format_tools(tools))
tries = 0
str_result = ""
Expand All @@ -56,7 +61,8 @@ def task_decompose(
return result["Tasks"] # type: ignore
except Exception:
if tries > 10:
raise ValueError(f"Failed task_decompose on: {str_result}")
_LOGGER.error(f"Failed task_decompose on: {str_result}")
return None
tries += 1
continue

Expand All @@ -81,14 +87,15 @@ def task_topology(
return result["Tasks"] # type: ignore
except Exception:
if tries > 10:
raise ValueError(f"Failed task_topology on: {str_result}")
_LOGGER.error(f"Failed task_topology on: {str_result}")
return task_list
tries += 1
continue


def choose_tool(
model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
) -> int:
) -> Optional[int]:
prompt = CHOOSE_TOOL.format(question=question, tools=format_tools(tools))
tries = 0
str_result = ""
Expand All @@ -99,14 +106,15 @@ def choose_tool(
return result["ID"] # type: ignore
except Exception:
if tries > 10:
raise ValueError(f"Failed choose_tool on: {str_result}")
_LOGGER.error(f"Failed choose_tool on: {str_result}")
return None
tries += 1
continue


def choose_parameter(
model: Union[LLM, LMM, Agent], question: str, tool_usage: Dict, previous_log: str
) -> Any:
) -> Optional[Any]:
# TODO: should format tool_usage
prompt = CHOOSE_PARAMETER.format(
question=question, tool_usage=tool_usage, previous_log=previous_log
Expand All @@ -120,7 +128,8 @@ def choose_parameter(
return result["Parameters"]
except Exception:
if tries > 10:
raise ValueError(f"Failed choose_parameter on: {str_result}")
_LOGGER.error(f"Failed choose_parameter on: {str_result}")
return None
tries += 1
continue

Expand Down Expand Up @@ -155,20 +164,21 @@ def retrieval(
tool_id = choose_tool(
model, question, {k: v["description"] for k, v in tools.items()}
)
if tool_id is None: # TODO
pass
if tool_id is None:
return [{}], ""
_LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

tool_instructions = tools[tool_id]
tool_usage = tool_instructions["usage"]
tool_name = tool_instructions["name"]

parameters = choose_parameter(model, question, tool_usage, previous_log)
if parameters is None: # TODO
pass
tool_results = [{"tool_name": tool_name, "parameters": parameters}]

if len(tool_results) == 0: # TODO
pass
_LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
if parameters is None:
return [{}], ""
tool_results = [
{"task": question, "tool_name": tool_name, "parameters": parameters}
]

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results: List[Any] = []
Expand All @@ -186,14 +196,12 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
return call_results

call_results = []
__import__("ipdb").set_trace()
if isinstance(tool_results, Set) or isinstance(tool_results, List):
for result in tool_results:
call_results.extend(parse_tool_results(result))
elif isinstance(tool_results, Dict):
call_results.extend(parse_tool_results(tool_results))
for i, result in enumerate(tool_results):
call_results.extend(parse_tool_results(result))
tool_results[i]["call_results"] = call_results

call_results_str = "\n\n".join([str(e) for e in call_results])
_LOGGER.info(f"\tCall Results: {call_results_str}")
return tool_results, call_results_str


Expand All @@ -217,6 +225,7 @@ def __init__(
self,
task_model: Optional[Union[LLM, LMM]] = None,
answer_model: Optional[Union[LLM, LMM]] = None,
verbose: bool = False,
):
self.task_model = (
OpenAILLM(json_mode=True) if task_model is None else task_model
Expand All @@ -225,6 +234,8 @@ def __init__(

self.retrieval_num = 3
self.tools = TOOLS
if verbose:
_LOGGER.setLevel(logging.INFO)

def __call__(
self,
Expand All @@ -235,9 +246,9 @@ def __call__(
input = [{"role": "user", "content": input}]
return self.chat(input, image=image)

def chat(
def chat_with_workflow(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> str:
) -> Tuple[str, List[Dict]]:
question = chat[0]["content"]
if image:
question += f" Image name: {image}"
Expand All @@ -246,17 +257,25 @@ def chat(
question,
{k: v["description"] for k, v in self.tools.items()},
)
task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
task_list = task_topology(self.task_model, question, task_list)
_LOGGER.info(f"Tasks: {tasks}")
if tasks is not None:
task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
task_list = task_topology(self.task_model, question, task_list)
else:
task_list = []

_LOGGER.info(f"Task Dependency: {task_list}")
task_depend = {"Original Quesiton": question}
previous_log = ""
answers = []
for task in task_list:
task_depend[task["id"]] = {"task": task["task"], "answer": ""} # type: ignore
# TODO topological sort task_list
all_tool_results = []
for task in task_list:
task_str = task["task"]
previous_log = str(task_depend)
_LOGGER.info(f"\tSubtask: {task_str}")
tool_results, call_results = retrieval(
self.task_model,
task_str,
Expand All @@ -266,6 +285,18 @@ def chat(
answer = answer_generate(
self.answer_model, task_str, call_results, previous_log
)

for tool_result in tool_results:
tool_result["answer"] = answer
all_tool_results.extend(tool_results)

_LOGGER.info(f"\tAnswer: {answer}")
answers.append({"task": task_str, "answer": answer})
task_depend[task["id"]]["answer"] = answer # type: ignore
return answer_summarize(self.answer_model, question, answers)
return answer_summarize(self.answer_model, question, answers), all_tool_results

def chat(
    self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> str:
    """Answer a chat conversation, returning only the final answer text.

    Convenience wrapper over ``chat_with_workflow`` that drops the
    accompanying per-task tool results and keeps just the summarized answer.

    Parameters:
        chat: conversation as a list of ``{"role": ..., "content": ...}`` dicts.
        image: optional path or name of an image to reference in the question.

    Returns:
        The final summarized answer string.
    """
    # chat_with_workflow returns (answer, all_tool_results); index 0 is the answer.
    return self.chat_with_workflow(chat, image=image)[0]

0 comments on commit dcb5d72

Please sign in to comment.