Skip to content

Commit

Permalink
fix: add back some logging for ui (#176)
Browse files Browse the repository at this point in the history
* fix: add back some logging for ui

* simplifier

* resolve lint error

* fix

* fix type error

* fix type error

* fix type error

* remove check
  • Loading branch information
wuyiqunLu authored Jul 22, 2024
1 parent 6c173d9 commit c8fc196
Showing 1 changed file with 49 additions and 40 deletions.
89 changes: 49 additions & 40 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,25 @@ def write_plans(
def pick_plan(
chat: List[Message],
plans: Dict[str, Any],
tool_info: str,
tool_infos: Dict[str, str],
model: LMM,
code_interpreter: CodeInterpreter,
test_multi_plan: bool,
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[str, str]:
) -> Tuple[Any, str, str]:
if not test_multi_plan:
k = list(plans.keys())[0]
return plans[k], tool_infos[k], ""

all_tool_info = tool_infos["all"]
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")

plan_str = format_plans(plans)
prompt = TEST_PLANS.format(
docstring=tool_info, plans=plan_str, previous_attempts=""
docstring=all_tool_info, plans=plan_str, previous_attempts=""
)

code = extract_code(model(prompt))
Expand All @@ -201,7 +207,7 @@ def pick_plan(
count = 0
while (not tool_output.success or tool_output_str == "") and count < max_retries:
prompt = TEST_PLANS.format(
docstring=tool_info,
docstring=all_tool_info,
plans=plan_str,
previous_attempts=PREVIOUS_FAILED.format(
code=code, error=tool_output.text()
Expand Down Expand Up @@ -237,7 +243,17 @@ def pick_plan(
best_plan = extract_json(model(chat))
if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
return best_plan["best_plan"], tool_output_str

plan = best_plan["best_plan"]
if plan in plans and plan in tool_infos:
return plans[plan], tool_infos[plan], tool_output_str
else:
if verbosity >= 1:
_LOGGER.warning(
f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
return plans[k], tool_infos[k], tool_output_str


@traceable
Expand Down Expand Up @@ -524,6 +540,13 @@ def retrieve_tools(
)
all_tools = "\n\n".join(set(tool_info))
tool_lists_unique["all"] = all_tools
log_progress(
{
"type": "tools",
"status": "completed",
"payload": tool_lists[list(plans.keys())[0]],
}
)
return tool_lists_unique


Expand Down Expand Up @@ -692,6 +715,14 @@ def chat_with_workflow(
self.planner,
)

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plans[list(plans.keys())[0]],
}
)

if self.verbosity >= 1 and test_multi_plan:
for p in plans:
_LOGGER.info(
Expand All @@ -705,47 +736,25 @@ def chat_with_workflow(
self.verbosity,
)

if test_multi_plan:
best_plan, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos["all"],
self.coder,
code_interpreter,
verbosity=self.verbosity,
)
else:
best_plan = list(plans.keys())[0]
tool_output_str = ""

if best_plan in plans and best_plan in tool_infos:
plan_i = plans[best_plan]
tool_info = tool_infos[best_plan]
else:
if self.verbosity >= 1:
_LOGGER.warning(
f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
plan_i = plans[k]
tool_info = tool_infos[k]

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plan_i,
}
best_plan, best_tool_info, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos,
self.coder,
code_interpreter,
test_multi_plan,
verbosity=self.verbosity,
)

if self.verbosity >= 1:
_LOGGER.info(
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
f"Picked best plan:\n{tabulate(tabular_data=best_plan, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

results = write_and_test_code(
chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
tool_info=tool_info,
plan="\n-" + "\n-".join([e["instructions"] for e in best_plan]),
tool_info=best_tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
Expand All @@ -761,7 +770,7 @@ def chat_with_workflow(
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})
plan.append({"code": code, "test": test, "plan": best_plan})

execution_result = cast(Execution, results["test_result"])
self.log_progress(
Expand Down

0 comments on commit c8fc196

Please sign in to comment.