fix: add back some logging for ui #176

Merged
merged 8 commits into from Jul 22, 2024
89 changes: 49 additions & 40 deletions vision_agent/agent/vision_agent.py
@@ -172,19 +172,25 @@ def write_plans(
def pick_plan(
chat: List[Message],
plans: Dict[str, Any],
tool_info: str,
tool_infos: Dict[str, str],
model: LMM,
code_interpreter: CodeInterpreter,
test_multi_plan: bool,
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[str, str]:
) -> Tuple[Any, str, str]:
if not test_multi_plan:
k = list(plans.keys())[0]
return plans[k], tool_infos[k], ""

all_tool_info = tool_infos["all"]
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")

plan_str = format_plans(plans)
prompt = TEST_PLANS.format(
docstring=tool_info, plans=plan_str, previous_attempts=""
docstring=all_tool_info, plans=plan_str, previous_attempts=""
)

code = extract_code(model(prompt))
@@ -201,7 +207,7 @@ def pick_plan(
count = 0
while (not tool_output.success or tool_output_str == "") and count < max_retries:
prompt = TEST_PLANS.format(
docstring=tool_info,
docstring=all_tool_info,
plans=plan_str,
previous_attempts=PREVIOUS_FAILED.format(
code=code, error=tool_output.text()
@@ -237,7 +243,17 @@ def pick_plan(
best_plan = extract_json(model(chat))
if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
return best_plan["best_plan"], tool_output_str

plan = best_plan["best_plan"]
if plan in plans and plan in tool_infos:
return plans[plan], tool_infos[plan], tool_output_str
else:
if verbosity >= 1:
_LOGGER.warning(
f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
return plans[k], tool_infos[k], tool_output_str


@traceable
@@ -524,6 +540,13 @@ def retrieve_tools(
)
all_tools = "\n\n".join(set(tool_info))
tool_lists_unique["all"] = all_tools
log_progress(
{
"type": "tools",
"status": "completed",
"payload": tool_lists[list(plans.keys())[0]],
}
)
return tool_lists_unique


@@ -692,6 +715,14 @@ def chat_with_workflow(
self.planner,
)

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plans[list(plans.keys())[0]],
}
)

if self.verbosity >= 1 and test_multi_plan:
for p in plans:
_LOGGER.info(
@@ -705,47 +736,25 @@ def chat_with_workflow(
self.verbosity,
)

if test_multi_plan:
best_plan, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos["all"],
self.coder,
code_interpreter,
verbosity=self.verbosity,
)
else:
best_plan = list(plans.keys())[0]
tool_output_str = ""

if best_plan in plans and best_plan in tool_infos:
plan_i = plans[best_plan]
tool_info = tool_infos[best_plan]
else:
if self.verbosity >= 1:
_LOGGER.warning(
f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
)
k = list(plans.keys())[0]
plan_i = plans[k]
tool_info = tool_infos[k]

self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plan_i,
}
best_plan, best_tool_info, tool_output_str = pick_plan(
int_chat,
plans,
tool_infos,
self.coder,
code_interpreter,
test_multi_plan,
verbosity=self.verbosity,
)

if self.verbosity >= 1:
_LOGGER.info(
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
f"Picked best plan:\n{tabulate(tabular_data=best_plan, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

results = write_and_test_code(
chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
tool_info=tool_info,
plan="\n-" + "\n-".join([e["instructions"] for e in best_plan]),
tool_info=best_tool_info,
tool_output=tool_output_str,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
@@ -761,7 +770,7 @@ def chat_with_workflow(
code = cast(str, results["code"])
test = cast(str, results["test"])
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})
plan.append({"code": code, "test": test, "plan": best_plan})

execution_result = cast(Execution, results["test_result"])
self.log_progress(
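
Net effect of this diff: pick_plan now owns both the multi-plan short-circuit and the fallback to the first plan, and always returns a 3-tuple of the chosen plan, its tool info, and the tool output string, so chat_with_workflow no longer branches on test_multi_plan itself. The snippet below is a minimal standalone sketch of that contract, not code from this PR; pick_plan_stub and the toy plan/tool-info data are made up for illustration.

    from typing import Any, Dict, Tuple


    def pick_plan_stub(
        plans: Dict[str, Any],
        tool_infos: Dict[str, str],
        test_multi_plan: bool,
    ) -> Tuple[Any, str, str]:
        # Mirrors the early return added in this PR: when multi-plan testing is
        # disabled, the first plan and its tool info are returned unchanged and
        # the tool output string is empty, so the caller never has to branch.
        if not test_multi_plan:
            k = list(plans.keys())[0]
            return plans[k], tool_infos[k], ""
        # The real pick_plan would format TEST_PLANS with tool_infos["all"],
        # run the generated test code, and ask the model for the best plan here.
        raise NotImplementedError("multi-plan testing path omitted in this sketch")


    plans = {"plan1": [{"instructions": "detect objects in the image"}]}
    tool_infos = {"plan1": "plan1 tool docstrings ...", "all": "all tool docstrings ..."}
    best_plan, best_tool_info, tool_output_str = pick_plan_stub(plans, tool_infos, False)
    print(best_plan, best_tool_info, repr(tool_output_str))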