diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 40b69257..65a46a6f 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -87,8 +87,8 @@ def format_memory(memory: List[Dict[str, str]]) -> str: def format_plans(plans: Dict[str, Any]) -> str: plan_str = "" for k, v in plans.items(): - plan_str += f"{k}:\n" - plan_str += "-" + "\n-".join([e["instructions"] for e in v]) + plan_str += "\n" + f"{k}: {v['thoughts']}\n" + plan_str += " -" + "\n -".join([e for e in v["instructions"]]) return plan_str @@ -229,9 +229,7 @@ def pick_plan( "status": "completed" if tool_output.success else "failed", } ) - tool_output_str = "" - if len(tool_output.logs.stdout) > 0: - tool_output_str = tool_output.logs.stdout[0] + tool_output_str = tool_output.text().strip() if verbosity == 2: _print_code("Code and test after attempted fix:", code) @@ -255,14 +253,15 @@ def pick_plan( count = 0 best_plan = None - while best_plan is None or count < max_retries: + while best_plan is None and count < max_retries: try: best_plan = extract_json(model(chat, stream=False)) # type: ignore except JSONDecodeError as _: + _LOGGER.exception("Error while extracting JSON during picking best plan") pass count += 1 - if count == max_retries: + if best_plan is None: best_plan = {"best_plan": list(plans.keys())[0]} if verbosity >= 1: @@ -537,7 +536,7 @@ def _print_code(title: str, code: str, test: Optional[str] = None) -> None: def retrieve_tools( - plans: Dict[str, List[Dict[str, str]]], + plans: Dict[str, Dict[str, Any]], tool_recommender: Sim, log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, @@ -554,8 +553,8 @@ def retrieve_tools( tool_lists: Dict[str, List[Dict[str, str]]] = {} for k, plan in plans.items(): tool_lists[k] = [] - for task in plan: - tools = tool_recommender.top_k(task["instructions"], k=2, thresh=0.3) + for task in plan["instructions"]: + tools = tool_recommender.top_k(task, k=2, thresh=0.3) tool_info.extend([e["doc"] for e in tools]) tool_desc.extend([e["desc"] for e in tools]) tool_lists[k].extend( @@ -749,14 +748,7 @@ def chat_with_workflow( if self.verbosity >= 1: for p in plans: # tabulate will fail if the keys are not the same for all elements - p_fixed = [ - { - "instructions": ( - e["instructions"] if "instructions" in e else "" - ) - } - for e in plans[p] - ] + p_fixed = [{"instructions": e} for e in plans[p]["instructions"]] _LOGGER.info( f"\n{tabulate(tabular_data=p_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) @@ -805,13 +797,14 @@ def chat_with_workflow( ) if self.verbosity >= 1: + plan_i_fixed = [{"instructions": e} for e in plan_i["instructions"]] _LOGGER.info( - f"Picked best plan:\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + f"Picked best plan:\n{tabulate(tabular_data=plan_i_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" ) results = write_and_test_code( chat=[{"role": c["role"], "content": c["content"]} for c in int_chat], - plan="\n-" + "\n-".join([e["instructions"] for e in plan_i]), + plan=f"\n{plan_i['thoughts']}\n-" + "\n-".join([e for e in plan_i["instructions"]]), tool_info=tool_info, tool_output=tool_output_str, tool_utils=T.UTILITIES_DOCSTRING,