Add tool testing (#164)
* adding image support for va

* save

* removed reflection

* image support for vision agent planning

* handle more media types for lmm

* final changes

* fix flake8

* format fix

* add tool testing

* remove array types from printed tool results

* fixed bug in prompt

* remove trailing space
dillonalaird authored Jul 10, 2024
1 parent 3e51cd9 commit 169d650
Showing 5 changed files with 316 additions and 123 deletions.
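Note: the core change replaces the single write_plan step with a test-driven selection loop: write_plans drafts several candidate plans, retrieve_tools gathers tool documentation per plan, and pick_plan executes generated test code to choose the best plan before any solution code is written. The plans dict flowing through these functions has this shape (keys and instruction text invented for illustration):

    plans = {
        "plan1": [
            {"instructions": "load the image with load_image"},
            {"instructions": "detect objects with owl_v2"},
        ],
        "plan2": [
            {"instructions": "load the image with load_image"},
            {"instructions": "detect objects with florence2_object_detection"},
        ],
    }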
249 changes: 163 additions & 86 deletions vision_agent/agent/vision_agent.py
@@ -5,7 +5,7 @@
 import sys
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

 from langsmith import traceable
 from PIL import Image
@@ -20,9 +20,11 @@
     CODE,
     FIX_BUG,
     FULL_TASK,
+    PICK_PLAN,
     PLAN,
-    REFLECT,
+    PREVIOUS_FAILED,
     SIMPLE_TEST,
+    TEST_PLANS,
     USER_REQ,
 )
 from vision_agent.lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
@@ -80,6 +82,15 @@ def format_memory(memory: List[Dict[str, str]]) -> str:
     return output_str


+def format_plans(plans: Dict[str, Any]) -> str:
+    plan_str = ""
+    for k, v in plans.items():
+        plan_str += f"{k}:\n"
+        plan_str += "-" + "\n-".join([e["instructions"] for e in v])
+
+    return plan_str
+
+
 def extract_code(code: str) -> str:
     if "\n```python" in code:
         start = "\n```python"
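Note: a quick self-contained check of what format_plans produces (plan contents invented); each plan key is followed by its dash-prefixed instructions:

    plans = {"plan1": [{"instructions": "detect dogs"}, {"instructions": "count them"}]}
    print(format_plans(plans))
    # plan1:
    # -detect dogs
    # -count them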
Expand Down Expand Up @@ -140,12 +151,12 @@ def extract_image(


@traceable
def write_plan(
def write_plans(
chat: List[Message],
tool_desc: str,
working_memory: str,
model: LMM,
) -> List[Dict[str, str]]:
) -> Dict[str, Any]:
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")
@@ -154,14 +165,84 @@ def write_plan(
     context = USER_REQ.format(user_request=user_request)
     prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
     chat[-1]["content"] = prompt
-    return extract_json(model.chat(chat))["plan"]  # type: ignore
+    return extract_json(model.chat(chat))


+@traceable
+def pick_plan(
+    chat: List[Message],
+    plans: Dict[str, Any],
+    tool_info: str,
+    model: LMM,
+    code_interpreter: CodeInterpreter,
+    verbosity: int = 0,
+) -> Tuple[str, str]:
+    chat = copy.deepcopy(chat)
+    if chat[-1]["role"] != "user":
+        raise ValueError("Last chat message must be from the user.")
+
+    plan_str = format_plans(plans)
+    prompt = TEST_PLANS.format(
+        docstring=tool_info, plans=plan_str, previous_attempts=""
+    )
+
+    code = extract_code(model(prompt))
+    tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
+    tool_output_str = ""
+    if len(tool_output.logs.stdout) > 0:
+        tool_output_str = tool_output.logs.stdout[0]
+
+    if verbosity >= 1:
+        _print_code("Initial code and tests:", code)
+        _LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")
+
+    # retry if the tool output is empty or code fails
+    count = 1
+    while (not tool_output.success or tool_output_str == "") and count < 3:
+        prompt = TEST_PLANS.format(
+            docstring=tool_info,
+            plans=plan_str,
+            previous_attempts=PREVIOUS_FAILED.format(
+                code=code, error=tool_output.text()
+            ),
+        )
+        code = extract_code(model(prompt))
+        tool_output = code_interpreter.exec_isolation(
+            DefaultImports.prepend_imports(code)
+        )
+        tool_output_str = ""
+        if len(tool_output.logs.stdout) > 0:
+            tool_output_str = tool_output.logs.stdout[0]
+
+        if verbosity == 1:
+            _print_code("Code and test after attempted fix:", code)
+            _LOGGER.info(f"Code execution result after attempt {count}")
+
+        count += 1
+
+    user_req = chat[-1]["content"]
+    context = USER_REQ.format(user_request=user_req)
+    # because the tool picker model gets the image as well, we have to be careful with
+    # how much text we send it, so we truncate the tool output to 20,000 characters
+    prompt = PICK_PLAN.format(
+        context=context,
+        plans=format_plans(plans),
+        tool_output=tool_output_str[:20_000],
+    )
+    chat[-1]["content"] = prompt
+    best_plan = extract_json(model(chat))
+    if verbosity >= 1:
+        _LOGGER.info(f"Best plan:\n{best_plan}")
+    return best_plan["best_plan"], tool_output_str
+
+
 @traceable
 def write_code(
     coder: LMM,
     chat: List[Message],
+    plan: str,
     tool_info: str,
+    tool_output: str,
     feedback: str,
 ) -> str:
     chat = copy.deepcopy(chat)
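Note: pick_plan follows a generate-execute-retry pattern: it writes test code against the candidate plans, runs it in an isolated interpreter, and regenerates up to two more times if execution fails or prints nothing, before asking the model to pick a winner. A standalone sketch of that control flow, with the model and interpreter stubbed out (generate and run are hypothetical stand-ins, not part of the commit):

    from typing import Callable, Tuple

    def generate_with_retry(
        generate: Callable[[str], str],          # stand-in for model(prompt)
        run: Callable[[str], Tuple[bool, str]],  # stand-in for exec_isolation
        prompt: str,
        max_attempts: int = 3,
    ) -> Tuple[str, str]:
        # Generate code, run it, and regenerate with the failure appended,
        # mirroring the TEST_PLANS / PREVIOUS_FAILED loop in pick_plan.
        code = generate(prompt)
        success, stdout = run(code)
        attempt = 1
        while (not success or stdout == "") and attempt < max_attempts:
            code = generate(prompt + "\nPrevious failed attempt:\n" + code)
            success, stdout = run(code)
            attempt += 1
        return code, stdout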
@@ -171,7 +252,8 @@ def write_code(
     user_request = chat[-1]["content"]
     prompt = CODE.format(
         docstring=tool_info,
-        question=user_request,
+        question=FULL_TASK.format(user_request=user_request, subtasks=plan),
+        tool_output=tool_output,
         feedback=feedback,
     )
     chat[-1]["content"] = prompt
@@ -203,27 +285,11 @@ def write_test(
     return extract_code(tester(chat))


-@traceable
-def reflect(
-    chat: List[Message],
-    plan: str,
-    code: str,
-    model: LMM,
-) -> Dict[str, Union[str, bool]]:
-    chat = copy.deepcopy(chat)
-    if chat[-1]["role"] != "user":
-        raise ValueError("Last chat message must be from the user.")
-
-    user_request = chat[-1]["content"]
-    context = USER_REQ.format(user_request=user_request)
-    prompt = REFLECT.format(context=context, plan=plan, code=code)
-    chat[-1]["content"] = prompt
-    return extract_json(model(chat))
-
-
 def write_and_test_code(
     chat: List[Message],
+    plan: str,
     tool_info: str,
+    tool_output: str,
     tool_utils: str,
     working_memory: List[Dict[str, str]],
     coder: LMM,
@@ -241,7 +307,14 @@ def write_and_test_code(
             "status": "started",
         }
     )
-    code = write_code(coder, chat, tool_info, format_memory(working_memory))
+    code = write_code(
+        coder,
+        chat,
+        plan,
+        tool_info,
+        tool_output,
+        format_memory(working_memory),
+    )
     test = write_test(
         tester, chat, tool_utils, code, format_memory(working_memory), media
     )
@@ -412,11 +485,11 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:


 def retrieve_tools(
-    plan: List[Dict[str, str]],
+    plans: Dict[str, List[Dict[str, str]]],
     tool_recommender: Sim,
     log_progress: Callable[[Dict[str, Any]], None],
     verbosity: int = 0,
-) -> str:
+) -> Dict[str, str]:
     log_progress(
         {
             "type": "tools",
@@ -425,27 +498,29 @@
     )
     tool_info = []
     tool_desc = []
-    tool_list: List[Dict[str, str]] = []
-    for task in plan:
-        tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
-        tool_info.extend([e["doc"] for e in tools])
-        tool_desc.extend([e["desc"] for e in tools])
-        tool_list.extend(
-            {"description": e["desc"], "documentation": e["doc"]} for e in tools
-        )
-    log_progress(
-        {
-            "type": "tools",
-            "status": "completed",
-            "payload": list({v["description"]: v for v in tool_list}.values()),
-        }
-    )
+    tool_lists: Dict[str, List[Dict[str, str]]] = {}
+    for k, plan in plans.items():
+        tool_lists[k] = []
+        for task in plan:
+            tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3)
+            tool_info.extend([e["doc"] for e in tools])
+            tool_desc.extend([e["desc"] for e in tools])
+            tool_lists[k].extend(
+                {"description": e["desc"], "documentation": e["doc"]} for e in tools
+            )

     if verbosity == 2:
         tool_desc_str = "\n".join(set(tool_desc))
         _LOGGER.info(f"Tools Description:\n{tool_desc_str}")
-    tool_info_set = set(tool_info)
-    return "\n\n".join(tool_info_set)
+
+    tool_lists_unique = {}
+    for k in tool_lists:
+        tool_lists_unique[k] = "\n\n".join(
+            set(e["documentation"] for e in tool_lists[k])
+        )
+    all_tools = "\n\n".join(set(tool_info))
+    tool_lists_unique["all"] = all_tools
+    return tool_lists_unique


 class VisionAgent(Agent):
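Note: retrieve_tools now returns a dict of tool documentation keyed by plan name, plus an "all" entry containing the union of every retrieved doc; pick_plan consumes tool_infos["all"], and the winning plan's entry is passed on to write_and_test_code. A sketch of the shape only (keys and doc strings invented; real entries come from tool_recommender.top_k):

    tool_infos = {
        "plan1": "owl_v2(prompt, image): open-vocabulary object detection...",
        "plan2": "florence2_object_detection(prompt, image): ...",
        "all": "owl_v2(prompt, image): ...\n\nflorence2_object_detection(prompt, image): ...",
    }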
@@ -543,7 +618,6 @@ def __call__(
     def chat_with_workflow(
         self,
         chat: List[Message],
-        self_reflection: bool = False,
         display_visualization: bool = False,
     ) -> Dict[str, Any]:
         """Chat with Vision Agent and return intermediate information regarding the task.
@@ -554,7 +628,6 @@ def chat_with_workflow(
                 [{"role": "user", "content": "describe your task here..."}]
                 or if it contains media files, it should be in the format of:
                 [{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
-            self_reflection (bool): Whether to reflect on the task and debug the code.
             display_visualization (bool): If True, it opens a new window locally to
                 show the image(s) created by visualization code (if there is any).
@@ -581,7 +654,10 @@ def chat_with_workflow(

         int_chat = cast(
             List[Message],
-            [{"role": c["role"], "content": c["content"]} for c in chat],
+            [
+                {"role": c["role"], "content": c["content"], "media": c["media"]}
+                for c in chat
+            ],
         )

         code = ""
@@ -599,13 +675,45 @@
                         "status": "started",
                     }
                 )
-                plan_i = write_plan(
+                plans = write_plans(
                     int_chat,
                     T.TOOL_DESCRIPTIONS,
                     format_memory(working_memory),
                     self.planner,
                 )
-                plan_i_str = "\n-".join([e["instructions"] for e in plan_i])
+
+                if self.verbosity >= 1:
+                    for p in plans:
+                        _LOGGER.info(
+                            f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
+                        )
+
+                tool_infos = retrieve_tools(
+                    plans,
+                    self.tool_recommender,
+                    self.log_progress,
+                    self.verbosity,
+                )
+                best_plan, tool_output_str = pick_plan(
+                    int_chat,
+                    plans,
+                    tool_infos["all"],
+                    self.coder,
+                    code_interpreter,
+                    verbosity=self.verbosity,
+                )
+
+                if best_plan in plans and best_plan in tool_infos:
+                    plan_i = plans[best_plan]
+                    tool_info = tool_infos[best_plan]
+                else:
+                    if self.verbosity >= 1:
+                        _LOGGER.warning(
+                            f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
+                        )
+                    k = list(plans.keys())[0]
+                    plan_i = plans[k]
+                    tool_info = tool_infos[k]

                 self.log_progress(
                     {
@@ -616,18 +724,16 @@
                 )
                 if self.verbosity >= 1:
                     _LOGGER.info(
-                        f"\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
+                        f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
                     )

-                tool_info = retrieve_tools(
-                    plan_i,
-                    self.tool_recommender,
-                    self.log_progress,
-                    self.verbosity,
-                )
                 results = write_and_test_code(
-                    chat=int_chat,
+                    chat=[
+                        {"role": c["role"], "content": c["content"]} for c in int_chat
+                    ],
+                    plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
                     tool_info=tool_info,
+                    tool_output=tool_output_str,
                     tool_utils=T.UTILITIES_DOCSTRING,
                     working_memory=working_memory,
                     coder=self.coder,
@@ -644,35 +750,6 @@
                 working_memory.extend(results["working_memory"])  # type: ignore
                 plan.append({"code": code, "test": test, "plan": plan_i})

-                if not self_reflection:
-                    break
-
-                self.log_progress(
-                    {
-                        "type": "self_reflection",
-                        "status": "started",
-                    }
-                )
-                reflection = reflect(
-                    int_chat,
-                    FULL_TASK.format(
-                        user_request=chat[0]["content"], subtasks=plan_i_str
-                    ),
-                    code,
-                    self.planner,
-                )
-                if self.verbosity > 0:
-                    _LOGGER.info(f"Reflection: {reflection}")
-                feedback = cast(str, reflection["feedback"])
-                success = cast(bool, reflection["success"])
-                self.log_progress(
-                    {
-                        "type": "self_reflection",
-                        "status": "completed" if success else "failed",
-                        "payload": reflection,
-                    }
-                )
-                working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})
                 retries += 1

         execution_result = cast(Execution, results["test_result"])