Skip to content

Commit c8fc196

Browse files
authored
fix: add back some logging for ui (#176)
* fix: add back some logging for ui * simplifier * resolve lint error * fix * fix type error * fix type error * fix type error * remove check
1 parent 6c173d9 commit c8fc196

File tree

1 file changed

+49
-40
lines changed

1 file changed

+49
-40
lines changed

vision_agent/agent/vision_agent.py

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -172,19 +172,25 @@ def write_plans(
172172
def pick_plan(
173173
chat: List[Message],
174174
plans: Dict[str, Any],
175-
tool_info: str,
175+
tool_infos: Dict[str, str],
176176
model: LMM,
177177
code_interpreter: CodeInterpreter,
178+
test_multi_plan: bool,
178179
verbosity: int = 0,
179180
max_retries: int = 3,
180-
) -> Tuple[str, str]:
181+
) -> Tuple[Any, str, str]:
182+
if not test_multi_plan:
183+
k = list(plans.keys())[0]
184+
return plans[k], tool_infos[k], ""
185+
186+
all_tool_info = tool_infos["all"]
181187
chat = copy.deepcopy(chat)
182188
if chat[-1]["role"] != "user":
183189
raise ValueError("Last chat message must be from the user.")
184190

185191
plan_str = format_plans(plans)
186192
prompt = TEST_PLANS.format(
187-
docstring=tool_info, plans=plan_str, previous_attempts=""
193+
docstring=all_tool_info, plans=plan_str, previous_attempts=""
188194
)
189195

190196
code = extract_code(model(prompt))
@@ -201,7 +207,7 @@ def pick_plan(
201207
count = 0
202208
while (not tool_output.success or tool_output_str == "") and count < max_retries:
203209
prompt = TEST_PLANS.format(
204-
docstring=tool_info,
210+
docstring=all_tool_info,
205211
plans=plan_str,
206212
previous_attempts=PREVIOUS_FAILED.format(
207213
code=code, error=tool_output.text()
@@ -237,7 +243,17 @@ def pick_plan(
237243
best_plan = extract_json(model(chat))
238244
if verbosity >= 1:
239245
_LOGGER.info(f"Best plan:\n{best_plan}")
240-
return best_plan["best_plan"], tool_output_str
246+
247+
plan = best_plan["best_plan"]
248+
if plan in plans and plan in tool_infos:
249+
return plans[plan], tool_infos[plan], tool_output_str
250+
else:
251+
if verbosity >= 1:
252+
_LOGGER.warning(
253+
f"Best plan {plan} not found in plans or tool_infos. Using the first plan and tool info."
254+
)
255+
k = list(plans.keys())[0]
256+
return plans[k], tool_infos[k], tool_output_str
241257

242258

243259
@traceable
@@ -524,6 +540,13 @@ def retrieve_tools(
524540
)
525541
all_tools = "\n\n".join(set(tool_info))
526542
tool_lists_unique["all"] = all_tools
543+
log_progress(
544+
{
545+
"type": "tools",
546+
"status": "completed",
547+
"payload": tool_lists[list(plans.keys())[0]],
548+
}
549+
)
527550
return tool_lists_unique
528551

529552

@@ -692,6 +715,14 @@ def chat_with_workflow(
692715
self.planner,
693716
)
694717

718+
self.log_progress(
719+
{
720+
"type": "plans",
721+
"status": "completed",
722+
"payload": plans[list(plans.keys())[0]],
723+
}
724+
)
725+
695726
if self.verbosity >= 1 and test_multi_plan:
696727
for p in plans:
697728
_LOGGER.info(
@@ -705,47 +736,25 @@ def chat_with_workflow(
705736
self.verbosity,
706737
)
707738

708-
if test_multi_plan:
709-
best_plan, tool_output_str = pick_plan(
710-
int_chat,
711-
plans,
712-
tool_infos["all"],
713-
self.coder,
714-
code_interpreter,
715-
verbosity=self.verbosity,
716-
)
717-
else:
718-
best_plan = list(plans.keys())[0]
719-
tool_output_str = ""
720-
721-
if best_plan in plans and best_plan in tool_infos:
722-
plan_i = plans[best_plan]
723-
tool_info = tool_infos[best_plan]
724-
else:
725-
if self.verbosity >= 1:
726-
_LOGGER.warning(
727-
f"Best plan {best_plan} not found in plans or tool_infos. Using the first plan and tool info."
728-
)
729-
k = list(plans.keys())[0]
730-
plan_i = plans[k]
731-
tool_info = tool_infos[k]
732-
733-
self.log_progress(
734-
{
735-
"type": "plans",
736-
"status": "completed",
737-
"payload": plan_i,
738-
}
739+
best_plan, best_tool_info, tool_output_str = pick_plan(
740+
int_chat,
741+
plans,
742+
tool_infos,
743+
self.coder,
744+
code_interpreter,
745+
test_multi_plan,
746+
verbosity=self.verbosity,
739747
)
748+
740749
if self.verbosity >= 1:
741750
_LOGGER.info(
742-
f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
751+
f"Picked best plan:\n{tabulate(tabular_data=best_plan, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
743752
)
744753

745754
results = write_and_test_code(
746755
chat=[{"role": c["role"], "content": c["content"]} for c in int_chat],
747-
plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]),
748-
tool_info=tool_info,
756+
plan="\n-" + "\n-".join([e["instructions"] for e in best_plan]),
757+
tool_info=best_tool_info,
749758
tool_output=tool_output_str,
750759
tool_utils=T.UTILITIES_DOCSTRING,
751760
working_memory=working_memory,
@@ -761,7 +770,7 @@ def chat_with_workflow(
761770
code = cast(str, results["code"])
762771
test = cast(str, results["test"])
763772
working_memory.extend(results["working_memory"]) # type: ignore
764-
plan.append({"code": code, "test": test, "plan": plan_i})
773+
plan.append({"code": code, "test": test, "plan": best_plan})
765774

766775
execution_result = cast(Execution, results["test_result"])
767776
self.log_progress(

0 commit comments

Comments
 (0)