Skip to content

Commit

Permalink
feat: new log with pick plan (#180)
Browse files Browse the repository at this point in the history
* feat: new log with pick plan

* fix lint

* format

* fix lint

* address comment
  • Loading branch information
wuyiqunLu authored Jul 26, 2024
1 parent 0f616e0 commit 82c9b6f
Showing 1 changed file with 105 additions and 36 deletions.
141 changes: 105 additions & 36 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,29 @@ def pick_plan(
model: LMM,
code_interpreter: CodeInterpreter,
test_multi_plan: bool,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[Any, str, str]:
if not test_multi_plan:
k = list(plans.keys())[0]
log_progress(
{
"type": "logs",
"log_content": "Plans created",
"status": "completed",
"payload": plans[k],
}
)
return plans[k], tool_infos[k], ""

log_progress(
{
"type": "logs",
"log_content": "Generating code to pick best plan",
"status": "started",
}
)
all_tool_info = tool_infos["all"]
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
Expand All @@ -194,6 +210,14 @@ def pick_plan(
)

code = extract_code(model(prompt))
log_progress(
{
"type": "logs",
"log_content": "Executing code to test plan",
"code": code,
"status": "running",
}
)
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
Expand All @@ -203,6 +227,18 @@ def pick_plan(
_print_code("Initial code and tests:", code)
_LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")

log_progress(
{
"type": "logs",
"log_content": (
"Code execution succeed"
if tool_output.success
else "Code execution failed"
),
"payload": tool_output.to_json(),
"status": "completed" if tool_output.success else "failed",
}
)
# retry if the tool output is empty or code fails
count = 0
while (not tool_output.success or tool_output_str == "") and count < max_retries:
Expand All @@ -213,10 +249,33 @@ def pick_plan(
code=code, error=tool_output.text()
),
)
log_progress(
{
"type": "logs",
"log_content": "Retry running code",
"code": code,
"status": "running",
}
)
code = extract_code(model(prompt))
tool_output = code_interpreter.exec_isolation(
DefaultImports.prepend_imports(code)
)
log_progress(
{
"type": "logs",
"log_content": (
"Code execution succeed"
if tool_output.success
else "Code execution failed"
),
"code": code,
"payload": {
"result": tool_output.to_json(),
},
"status": "completed" if tool_output.success else "failed",
}
)
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
tool_output_str = tool_output.logs.stdout[0]
Expand Down Expand Up @@ -246,14 +305,26 @@ def pick_plan(

plan = best_plan["best_plan"]
if plan in plans and plan in tool_infos:
return plans[plan], tool_infos[plan], tool_output_str
best_plans = plans[plan]
best_tool_infos = tool_infos[plan]
else:
if verbosity >= 1:
_LOGGER.warning(
f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
return plans[k], tool_infos[k], tool_output_str
best_plans = plans[k]
best_tool_infos = tool_infos[k]

log_progress(
{
"type": "logs",
"log_content": "Picked best plan",
"status": "complete",
"payload": best_plans,
}
)
return best_plans, best_tool_infos, tool_output_str


@traceable
Expand Down Expand Up @@ -323,7 +394,8 @@ def write_and_test_code(
) -> Dict[str, Any]:
log_progress(
{
"type": "code",
"type": "log",
"log_content": "Generating code",
"status": "started",
}
)
Expand All @@ -341,10 +413,11 @@ def write_and_test_code(

log_progress(
{
"type": "code",
"type": "log",
"log_content": "Running code",
"status": "running",
"code": DefaultImports.prepend_imports(code),
"payload": {
"code": DefaultImports.prepend_imports(code),
"test": test,
},
}
Expand All @@ -354,10 +427,13 @@ def write_and_test_code(
)
log_progress(
{
"type": "code",
"type": "log",
"log_content": (
"Code execution succeed" if result.success else "Code execution failed"
),
"status": "completed" if result.success else "failed",
"code": DefaultImports.prepend_imports(code),
"payload": {
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": result.to_json(),
},
Expand Down Expand Up @@ -507,15 +583,8 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None:
def retrieve_tools(
plans: Dict[str, List[Dict[str, str]]],
tool_recommender: Sim,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
) -> Dict[str, str]:
log_progress(
{
"type": "tools",
"status": "started",
}
)
) -> Tuple[Dict[str, str], Dict[str, List[Dict[str, str]]]]:
tool_info = []
tool_desc = []
tool_lists: Dict[str, List[Dict[str, str]]] = {}
Expand All @@ -526,7 +595,12 @@ def retrieve_tools(
tool_info.extend([e["doc"] for e in tools])
tool_desc.extend([e["desc"] for e in tools])
tool_lists[k].extend(
{"description": e["desc"], "documentation": e["doc"]} for e in tools
{
"plan": task["instructions"] if index == 0 else "",
"tool": e["desc"].strip().split()[0],
"documentation": e["doc"],
}
for index, e in enumerate(tools)
)

if verbosity == 2:
Expand All @@ -540,14 +614,7 @@ def retrieve_tools(
)
all_tools = "\n\n".join(set(tool_info))
tool_lists_unique["all"] = all_tools
log_progress(
{
"type": "tools",
"status": "completed",
"payload": tool_lists[list(plans.keys())[0]],
}
)
return tool_lists_unique
return tool_lists_unique, tool_lists


class VisionAgent(Agent):
Expand Down Expand Up @@ -704,7 +771,8 @@ def chat_with_workflow(

self.log_progress(
{
"type": "plans",
"type": "logs",
"log_content": "Creating plans",
"status": "started",
}
)
Expand All @@ -715,27 +783,28 @@ def chat_with_workflow(
self.planner,
)

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plans[list(plans.keys())[0]],
}
)

if self.verbosity >= 1 and test_multi_plan:
for p in plans:
_LOGGER.info(
f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_infos = retrieve_tools(
tool_infos, tool_lists = retrieve_tools(
plans,
self.tool_recommender,
self.log_progress,
self.verbosity,
)

if test_multi_plan:
self.log_progress(
{
"type": "logs",
"log_content": "Creating plans",
"status": "completed",
"payload": tool_lists,
}
)

best_plan, best_tool_info, tool_output_str = pick_plan(
int_chat,
plans,
Expand Down Expand Up @@ -777,8 +846,8 @@ def chat_with_workflow(
{
"type": "final_code",
"status": "completed" if success else "failed",
"code": DefaultImports.prepend_imports(code),
"payload": {
"code": DefaultImports.prepend_imports(code),
"test": test,
"result": execution_result.to_json(),
},
Expand Down

0 comments on commit 82c9b6f

Please sign in to comment.