Fixes for EasyTool #21

Merged 1 commit on Mar 20, 2024
85 changes: 58 additions & 27 deletions vision_agent/agent/easytool.py
@@ -1,6 +1,8 @@
 import json
+import logging
+import sys
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 from vision_agent import LLM, LMM, OpenAILLM
 from vision_agent.tools import TOOLS
@@ -15,6 +17,9 @@
     TASK_TOPOLOGY,
 )
 
+logging.basicConfig(stream=sys.stdout)
+_LOGGER = logging.getLogger(__name__)
+
 
 def parse_json(s: str) -> Any:
     s = (
@@ -45,7 +50,7 @@ def format_tools(tools: Dict[int, Any]) -> str:
 
 def task_decompose(
     model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
-) -> Dict:
+) -> Optional[Dict]:
     prompt = TASK_DECOMPOSE.format(question=question, tools=format_tools(tools))
     tries = 0
     str_result = ""
@@ -56,7 +61,8 @@ def task_decompose(
             return result["Tasks"]  # type: ignore
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed task_decompose on: {str_result}")
+                _LOGGER.error(f"Failed task_decompose on: {str_result}")
+                return None
             tries += 1
             continue

@@ -81,14 +87,15 @@ def task_topology(
             return result["Tasks"]  # type: ignore
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed task_topology on: {str_result}")
+                _LOGGER.error(f"Failed task_topology on: {str_result}")
+                return task_list
             tries += 1
             continue
 
 
 def choose_tool(
     model: Union[LLM, LMM, Agent], question: str, tools: Dict[int, Any]
-) -> int:
+) -> Optional[int]:
     prompt = CHOOSE_TOOL.format(question=question, tools=format_tools(tools))
     tries = 0
     str_result = ""
@@ -99,14 +106,15 @@ def choose_tool(
             return result["ID"]  # type: ignore
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed choose_tool on: {str_result}")
+                _LOGGER.error(f"Failed choose_tool on: {str_result}")
+                return None
             tries += 1
             continue
 
 
 def choose_parameter(
     model: Union[LLM, LMM, Agent], question: str, tool_usage: Dict, previous_log: str
-) -> Any:
+) -> Optional[Any]:
     # TODO: should format tool_usage
     prompt = CHOOSE_PARAMETER.format(
         question=question, tool_usage=tool_usage, previous_log=previous_log
@@ -120,7 +128,8 @@ def choose_parameter(
             return result["Parameters"]
         except Exception:
             if tries > 10:
-                raise ValueError(f"Failed choose_parameter on: {str_result}")
+                _LOGGER.error(f"Failed choose_parameter on: {str_result}")
+                return None
             tries += 1
             continue

@@ -155,20 +164,21 @@ def retrieval(
     tool_id = choose_tool(
         model, question, {k: v["description"] for k, v in tools.items()}
     )
-    if tool_id is None:  # TODO
-        pass
+    if tool_id is None:
+        return [{}], ""
+    _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")
 
     tool_instructions = tools[tool_id]
     tool_usage = tool_instructions["usage"]
     tool_name = tool_instructions["name"]
 
     parameters = choose_parameter(model, question, tool_usage, previous_log)
-    if parameters is None:  # TODO
-        pass
-    tool_results = [{"tool_name": tool_name, "parameters": parameters}]
-
-    if len(tool_results) == 0:  # TODO
-        pass
+    _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
+    if parameters is None:
+        return [{}], ""
+    tool_results = [
+        {"task": question, "tool_name": tool_name, "parameters": parameters}
+    ]
 
     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
         call_results: List[Any] = []
@@ -186,14 +196,12 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
         return call_results
 
     call_results = []
-    __import__("ipdb").set_trace()
-    if isinstance(tool_results, Set) or isinstance(tool_results, List):
-        for result in tool_results:
-            call_results.extend(parse_tool_results(result))
-    elif isinstance(tool_results, Dict):
-        call_results.extend(parse_tool_results(tool_results))
+    for i, result in enumerate(tool_results):
+        call_results.extend(parse_tool_results(result))
+        tool_results[i]["call_results"] = call_results
 
     call_results_str = "\n\n".join([str(e) for e in call_results])
+    _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str
 
 
@@ -217,6 +225,7 @@ def __init__(
         self,
         task_model: Optional[Union[LLM, LMM]] = None,
         answer_model: Optional[Union[LLM, LMM]] = None,
+        verbose: bool = False,
     ):
         self.task_model = (
             OpenAILLM(json_mode=True) if task_model is None else task_model
@@ -225,6 +234,8 @@
 
         self.retrieval_num = 3
         self.tools = TOOLS
+        if verbose:
+            _LOGGER.setLevel(logging.INFO)
 
     def __call__(
         self,
@@ -235,9 +246,9 @@ def __call__(
             input = [{"role": "user", "content": input}]
         return self.chat(input, image=image)
 
-    def chat(
+    def chat_with_workflow(
         self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
-    ) -> str:
+    ) -> Tuple[str, List[Dict]]:
         question = chat[0]["content"]
         if image:
             question += f" Image name: {image}"
@@ -246,17 +257,25 @@ def chat(
             question,
             {k: v["description"] for k, v in self.tools.items()},
         )
-        task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
-        task_list = task_topology(self.task_model, question, task_list)
+        _LOGGER.info(f"Tasks: {tasks}")
+        if tasks is not None:
+            task_list = [{"task": task, "id": i + 1} for i, task in enumerate(tasks)]
+            task_list = task_topology(self.task_model, question, task_list)
+        else:
+            task_list = []
+
+        _LOGGER.info(f"Task Dependency: {task_list}")
         task_depend = {"Original Quesiton": question}
         previous_log = ""
         answers = []
         for task in task_list:
             task_depend[task["id"]] = {"task": task["task"], "answer": ""}  # type: ignore
         # TODO topological sort task_list
+        all_tool_results = []
         for task in task_list:
             task_str = task["task"]
             previous_log = str(task_depend)
+            _LOGGER.info(f"\tSubtask: {task_str}")
             tool_results, call_results = retrieval(
                 self.task_model,
                 task_str,
@@ -266,6 +285,18 @@ def chat(
             answer = answer_generate(
                 self.answer_model, task_str, call_results, previous_log
             )
+
+            for tool_result in tool_results:
+                tool_result["answer"] = answer
+            all_tool_results.extend(tool_results)
+
+            _LOGGER.info(f"\tAnswer: {answer}")
             answers.append({"task": task_str, "answer": answer})
             task_depend[task["id"]]["answer"] = answer  # type: ignore
-        return answer_summarize(self.answer_model, question, answers)
+        return answer_summarize(self.answer_model, question, answers), all_tool_results
+
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
+        answer, _ = self.chat_with_workflow(chat, image=image)
+        return answer
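
The recurring edit across `task_decompose`, `task_topology`, `choose_tool`, and `choose_parameter` is one pattern: after ten failed attempts to parse the model's JSON reply, log the raw output and return a fallback (`None`, or the unmodified task list) instead of raising, so a single malformed reply no longer aborts the whole workflow. A minimal sketch of that pattern with hypothetical names (`ask_for_json` is not a helper in this module, and it uses plain `json.loads` where the module uses its own `parse_json`):

```python
import json
import logging
from typing import Any, Callable, Optional

_LOGGER = logging.getLogger(__name__)


def ask_for_json(model: Callable[[str], str], prompt: str, key: str,
                 max_tries: int = 10) -> Optional[Any]:
    # Keep re-asking the model until its reply parses as JSON.
    str_result = ""
    for _ in range(max_tries + 1):
        try:
            str_result = model(prompt)
            return json.loads(str_result)[key]
        except Exception:
            continue
    # Fallback path introduced by this PR: log the unparseable output and
    # return None so the caller can degrade gracefully, the way retrieval()
    # now returns ([{}], "") when choose_tool or choose_parameter gives None.
    _LOGGER.error(f"Failed to parse model output: {str_result}")
    return None
```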
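Taken together, the public surface after this PR is: `chat_with_workflow` returns the summarized answer plus the per-task tool results, while `chat` keeps the old string-returning interface. A usage sketch, assuming the class in easytool.py is exported as `EasyTool` and that credentials for the default `OpenAILLM` models are configured (neither is shown in this diff):

```python
from vision_agent.agent.easytool import EasyTool  # assumed import path

# The new verbose flag raises the module logger to INFO, so the per-subtask
# logs added in this PR (tool choice, parameters, call results, answers)
# are printed to stdout.
agent = EasyTool(verbose=True)

# chat_with_workflow (formerly chat) returns the answer plus the tool
# results; each result dict carries task, tool_name, parameters,
# call_results, and answer.
answer, tool_results = agent.chat_with_workflow(
    [{"role": "user", "content": "How many apples are in the image?"}],
    image="apples.jpg",  # hypothetical image path
)

# chat simply discards the tool results, preserving the old interface.
answer_only = agent.chat([{"role": "user", "content": "Describe the image."}])
```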