Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new log with pick plan #180

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 105 additions & 36 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,29 @@ def pick_plan(
model: LMM,
code_interpreter: CodeInterpreter,
test_multi_plan: bool,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[Any, str, str]:
if not test_multi_plan:
k = list(plans.keys())[0]
log_progress(
{
"type": "logs",
"log_content": "Plans created",
"status": "completed",
"payload": plans[k],
}
)
return plans[k], tool_infos[k], ""

log_progress(
{
"type": "logs",
"log_content": "Generating code to pick best plan",
"status": "started",
}
)
all_tool_info = tool_infos["all"]
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
Expand All @@ -194,6 +210,14 @@ def pick_plan(
)

code = extract_code(model(prompt))
log_progress(
{
"type": "logs",
"log_content": "Executing code to test plan",
"code": code,
"status": "running",
}
)
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
Expand All @@ -203,6 +227,18 @@ def pick_plan(
_print_code("Initial code and tests:", code)
_LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")

log_progress(
{
"type": "logs",
"log_content": (
"Code execution succeed"
if tool_output.success
else "Code execution failed"
),
"payload": tool_output.to_json(),
"status": "completed" if tool_output.success else "failed",
}
)
# retry if the tool output is empty or code fails
count = 0
while (not tool_output.success or tool_output_str == "") and count < max_retries:
Expand All @@ -213,10 +249,33 @@ def pick_plan(
code=code, error=tool_output.text()
),
)
log_progress(
{
"type": "logs",
"log_content": "Retry running code",
"code": code,
"status": "running",
}
)
code = extract_code(model(prompt))
tool_output = code_interpreter.exec_isolation(
DefaultImports.prepend_imports(code)
)
log_progress(
{
"type": "logs",
"log_content": (
"Code execution succeed"
if tool_output.success
else "Code execution failed"
),
"code": code,
"payload": {
"result": tool_output.to_json(),
},
"status": "completed" if tool_output.success else "failed",
}
)
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]
Expand Down Expand Up @@ -246,14 +305,26 @@ def pick_plan(

plan = best_plan["best_plan"]
if plan in plans and plan in tool_infos:
return plans[plan], tool_infos[plan], tool_output_str
best_plans = plans[plan]
best_tool_infos = tool_infos[plan]
else:
if verbosity >= 1:
_LOGGER.warning(
f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
return plans[k], tool_infos[k], tool_output_str
best_plans = plans[k]
best_tool_infos = tool_infos[k]

log_progress(
{
"type": "logs",
"log_content": "Picked best plan",
"status": "complete",
"payload": best_plans,
}
)
return best_plans, best_tool_infos, tool_output_str


@traceable
Expand Down Expand Up @@ -323,7 +394,8 @@ def write_and_test_code(
) -> Dict[str, Any]:
log_progress(
{
"type": "code",
"type": "log",
"log_content": "Generating code",
"status": "started",
}
)
Expand All @@ -341,10 +413,11 @@ def write_and_test_code(

log_progress(
{
"type": "code",
"type": "log",
"log_content": "Running code",
"status": "running",
"code": DefaultImports.prepend_imports(code),
"payload": {
"code": DefaultImports.prepend_imports(code),
"test": test,
},
}
Expand All @@ -354,10 +427,13 @@ def write_and_test_code(
)
log_progress(
{
"type": "code",
"type": "log",
"log_content": (
"Code execution succeed" if result.success else "Code execution failed"
),
"status": "completed" if result.success else "failed",
"code": DefaultImports.prepend_imports(code),
"payload": {
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": result.to_json(),
},
Expand Down Expand Up @@ -507,15 +583,8 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
def retrieve_tools(
plans: Dict[str, List[Dict[str, str]]],
tool_recommender: Sim,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
) -> Dict[str, str]:
log_progress(
{
"type": "tools",
"status": "started",
}
)
) -> Tuple[Dict[str, str], Dict[str, List[Dict[str, str]]]]:
tool_info = []
tool_desc = []
tool_lists: Dict[str, List[Dict[str, str]]] = {}
Expand All @@ -526,7 +595,12 @@ def retrieve_tools(
tool_info.extend([e["doc"] for e in tools])
tool_desc.extend([e["desc"] for e in tools])
tool_lists[k].extend(
{"description": e["desc"], "documentation": e["doc"]} for e in tools
{
"plan": task["instructions"] if index == 0 else "",
"tool": e["desc"].strip().split()[0],
"documentation": e["doc"],
}
for index, e in enumerate(tools)
)

if verbosity == 2:
Expand All @@ -540,14 +614,7 @@ def retrieve_tools(
)
all_tools = "\n\n".join(set(tool_info))
tool_lists_unique["all"] = all_tools
log_progress(
{
"type": "tools",
"status": "completed",
"payload": tool_lists[list(plans.keys())[0]],
}
)
return tool_lists_unique
return tool_lists_unique, tool_lists


class VisionAgent(Agent):
Expand Down Expand Up @@ -704,7 +771,8 @@ def chat_with_workflow(

self.log_progress(
{
"type": "plans",
"type": "logs",
"log_content": "Creating plans",
"status": "started",
}
)
Expand All @@ -715,27 +783,28 @@ def chat_with_workflow(
self.planner,
)

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plans[list(plans.keys())[0]],
}
)

if self.verbosity >= 1 and test_multi_plan:
for p in plans:
_LOGGER.info(
f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_infos = retrieve_tools(
tool_infos, tool_lists = retrieve_tools(
plans,
self.tool_recommender,
self.log_progress,
self.verbosity,
)

if test_multi_plan:
self.log_progress(
{
"type": "logs",
"log_content": "Creating plans",
"status": "completed",
"payload": tool_lists,
}
)

best_plan, best_tool_info, tool_output_str = pick_plan(
int_chat,
plans,
Expand Down Expand Up @@ -777,8 +846,8 @@ def chat_with_workflow(
{
"type": "final_code",
"status": "completed" if success else "failed",
"code": DefaultImports.prepend_imports(code),
"payload": {
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": execution_result.to_json(),
},
Expand Down
Loading