Skip to content

Commit

Permalink
add tool testing
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Jul 8, 2024
1 parent 7d0f7e9 commit ad20576
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 55 deletions.
186 changes: 148 additions & 38 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

from langsmith import traceable
from PIL import Image
Expand All @@ -19,8 +19,12 @@
from vision_agent.agent.vision_agent_prompts import (
CODE,
FIX_BUG,
FULL_TASK,
PICK_PLAN,
PLAN,
PREVIOUS_FAILED,
SIMPLE_TEST,
TEST_PLANS,
USER_REQ,
)
from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
Expand Down Expand Up @@ -78,6 +82,15 @@ def format_memory(memory: List[Dict[str, str]]) -> str:
return output_str


def format_plans(plans: Dict[str, Any]) -> str:
    """Render candidate plans as a plain-text listing for use in a prompt.

    Each plan is written as its key followed by a colon on one line, then one
    "-"-prefixed line per task taken from that task's "instructions" entry.
    Consecutive plans are separated by a newline so that a plan header never
    runs into the last instruction of the preceding plan.

    Args:
        plans: Mapping from plan name to a list of task dicts. Every task
            dict must provide an "instructions" key.

    Returns:
        The formatted listing, or an empty string when `plans` is empty.
    """
    sections = []
    for name, tasks in plans.items():
        lines = [f"{name}:"]
        # A plan with no tasks contributes only its header line; no dangling
        # "-" bullet is emitted.
        lines.extend(f"-{task['instructions']}" for task in tasks)
        sections.append("\n".join(lines))
    return "\n".join(sections)


def extract_code(code: str) -> str:
if "\n```python" in code:
start = "\n```python"
Expand Down Expand Up @@ -138,7 +151,7 @@ def extract_image(


@traceable
def write_plan(
def write_plans(
chat: List[Message],
tool_desc: str,
working_memory: str,
Expand All @@ -155,12 +168,79 @@ def write_plan(
return extract_json(model.chat(chat))


@traceable
def pick_plan(
    chat: List[Message],
    plans: Dict[str, Any],
    tool_info: str,
    model: LMM,
    code_interpreter: CodeInterpreter,
    verbosity: int = 0,
) -> Tuple[str, str]:
    """Generate and run tool-test code for the plans, then ask the model to pick one.

    The model is first prompted (via TEST_PLANS) to write code exercising the
    tools for the given plans. That code is executed in isolation; if it fails
    or prints nothing, it is regenerated with the previous code and error fed
    back, for at most two retries. Finally the model is prompted (via
    PICK_PLAN) with the user request, the plans and the truncated tool output
    to choose the best plan.

    Args:
        chat: Conversation so far; the last message must be from the user.
            The list is deep-copied, so the caller's messages are not modified.
        plans: Candidate plans, in the shape accepted by `format_plans`.
        tool_info: Tool documentation inserted into the test-writing prompt.
        model: Model used both to write the test code and to pick the plan.
        code_interpreter: Interpreter used to execute the generated code.
        verbosity: When >= 1, the generated code and execution results are
            logged for the initial attempt and for every retry.

    Returns:
        A tuple `(best_plan, tool_output_str)`: the value stored under
        "best_plan" in the JSON extracted from the model's reply, and the
        first stdout entry of the last code execution ("" if there was none).
        The returned tool output is not truncated.

    Raises:
        ValueError: If `chat` is empty or its last message is not from the user.
    """
    chat = copy.deepcopy(chat)
    if not chat or chat[-1]["role"] != "user":
        raise ValueError("Last chat message must be from the user.")

    plan_str = format_plans(plans)
    prompt = TEST_PLANS.format(
        docstring=tool_info, plans=plan_str, previous_attempts=""
    )

    code = extract_code(model(prompt))
    tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
    # Only the first stdout entry is used as the tool output.
    tool_output_str = ""
    if len(tool_output.logs.stdout) > 0:
        tool_output_str = tool_output.logs.stdout[0]

    if verbosity >= 1:
        _print_code("Initial code and tests:", code)
        _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")

    # retry if the tool output is empty or code fails
    count = 1
    while (not tool_output.success or tool_output_str == "") and count < 3:
        prompt = TEST_PLANS.format(
            docstring=tool_info,
            plans=plan_str,
            previous_attempts=PREVIOUS_FAILED.format(
                code=code, error=tool_output.text()
            ),
        )
        code = extract_code(model(prompt))
        tool_output = code_interpreter.exec_isolation(
            DefaultImports.prepend_imports(code)
        )
        tool_output_str = ""
        if len(tool_output.logs.stdout) > 0:
            tool_output_str = tool_output.logs.stdout[0]

        # Same threshold as the initial attempt, so higher verbosity levels
        # do not lose the retry logs.
        if verbosity >= 1:
            _print_code("Code and test after attempted fix:", code)
            _LOGGER.info(
                f"Code execution result after attempt {count}:\n{tool_output.text()}"
            )

        count += 1

    user_req = chat[-1]["content"]
    context = USER_REQ.format(user_request=user_req)
    # because the tool picker model gets the image as well, we have to be careful with
    # how much text we send it, so we truncate the tool output to 500 characters
    prompt = PICK_PLAN.format(
        context=context,
        plans=plan_str,
        tool_output=tool_output_str[:500],
    )
    chat[-1]["content"] = prompt
    best_plan = extract_json(model(chat))
    return best_plan["best_plan"], tool_output_str


@traceable
def write_code(
coder: LMM,
chat: List[Message],
image_desc: str,
plan: str,
tool_info: str,
tool_output: str,
feedback: str,
) -> str:
chat = copy.deepcopy(chat)
Expand All @@ -170,9 +250,9 @@ def write_code(
user_request = chat[-1]["content"]
prompt = CODE.format(
docstring=tool_info,
question=user_request,
question=FULL_TASK.format(user_request=user_request, subtasks=plan),
tool_output=tool_output,
feedback=feedback,
image_desc=image_desc,
)
chat[-1]["content"] = prompt
return extract_code(coder(chat))
Expand Down Expand Up @@ -205,8 +285,9 @@ def write_test(

def write_and_test_code(
chat: List[Message],
image_desc: str,
plan: str,
tool_info: str,
tool_output: str,
tool_utils: str,
working_memory: List[Dict[str, str]],
coder: LMM,
Expand All @@ -227,8 +308,9 @@ def write_and_test_code(
code = write_code(
coder,
chat,
image_desc,
f"{tool_info}\n{tool_utils}",
plan,
tool_info,
tool_output,
format_memory(working_memory),
)
test = write_test(
Expand Down Expand Up @@ -401,11 +483,11 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:


def retrieve_tools(
plan: List[Dict[str, str]],
plans: Dict[str, List[Dict[str, str]]],
tool_recommender: Sim,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
) -> str:
) -> Dict[str, str]:
log_progress(
{
"type": "tools",
Expand All @@ -414,27 +496,29 @@ def retrieve_tools(
)
tool_info = []
tool_desc = []
tool_list: List[Dict[str, str]] = []
for task in plan:
tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
tool_info.extend([e["doc"] for e in tools])
tool_desc.extend([e["desc"] for e in tools])
tool_list.extend(
{"description": e["desc"], "documentation": e["doc"]} for e in tools
)
log_progress(
{
"type": "tools",
"status": "completed",
"payload": list({v["description"]: v for v in tool_list}.values()),
}
)
tool_lists: Dict[str, List[Dict[str, str]]] = {}
for k, plan in plans.items():
tool_lists[k] = []
for task in plan:
tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
tool_info.extend([e["doc"] for e in tools])
tool_desc.extend([e["desc"] for e in tools])
tool_lists[k].extend(
{"description": e["desc"], "documentation": e["doc"]} for e in tools
)

if verbosity == 2:
tool_desc_str = "\n".join(set(tool_desc))
_LOGGER.info(f"Tools Description:\n{tool_desc_str}")
tool_info_set = set(tool_info)
return "\n\n".join(tool_info_set)

tool_lists_unique = {}
for k in tool_lists:
tool_lists_unique[k] = "\n\n".join(
set(e["documentation"] for e in tool_lists[k])
)
all_tools = "\n\n".join(set(tool_info))
tool_lists_unique["all"] = all_tools
return tool_lists_unique


class VisionAgent(Agent):
Expand Down Expand Up @@ -589,14 +673,45 @@ def chat_with_workflow(
"status": "started",
}
)
planning = write_plan(
plans = write_plans(
int_chat,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
)
plan_i = planning["plan"]
image_desc = planning["image_desc"]

if self.verbosity >= 1:
for p in plans:
_LOGGER.info(
f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_infos = retrieve_tools(
plans,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
best_plan, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos["all"],
self.coder,
code_interpreter,
verbosity=self.verbosity,
)

if best_plan in plans and best_plan in tool_infos:
plan_i = plans[best_plan]
tool_info = tool_infos[best_plan]
else:
if self.verbosity >= 1:
_LOGGER.warning(
f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
plan_i = plans[k]
tool_info = tool_infos[k]

self.log_progress(
{
Expand All @@ -607,21 +722,16 @@ def chat_with_workflow(
)
if self.verbosity >= 1:
_LOGGER.info(
f"\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_info = retrieve_tools(
plan_i,
self.tool_recommender,
self.log_progress,
self.verbosity,
)
results = write_and_test_code(
chat=[
{"role": c["role"], "content": c["content"]} for c in int_chat
],
image_desc=image_desc,
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
tool_info=tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
coder=self.coder,
Expand Down
Loading

0 comments on commit ad20576

Please sign in to comment.