From 4e3cfaad0a3dae1908ab64725c96ec948e94f1f9 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 10 Oct 2024 17:23:54 -0700 Subject: [PATCH] fix issue if plan format is incorrect --- vision_agent/agent/agent_utils.py | 1 + vision_agent/agent/vision_agent_coder.py | 2 +- vision_agent/agent/vision_agent_planner.py | 33 +++++++++++++++++++++- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index 9b7ea02a..e07ec619 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -13,6 +13,7 @@ logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) _CONSOLE = Console() +_MAX_TABULATE_COL_WIDTH = 80 def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 345c4552..bc2de2a3 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -11,6 +11,7 @@ import vision_agent.tools as T from vision_agent.agent.agent import Agent from vision_agent.agent.agent_utils import ( + _MAX_TABULATE_COL_WIDTH, DefaultImports, extract_code, extract_json, @@ -46,7 +47,6 @@ logging.basicConfig(stream=sys.stdout) WORKSPACE = Path(os.getenv("WORKSPACE", "")) _LOGGER = logging.getLogger(__name__) -_MAX_TABULATE_COL_WIDTH = 80 def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str: diff --git a/vision_agent/agent/vision_agent_planner.py b/vision_agent/agent/vision_agent_planner.py index 581431a5..fb746e86 100644 --- a/vision_agent/agent/vision_agent_planner.py +++ b/vision_agent/agent/vision_agent_planner.py @@ -5,10 +5,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from pydantic import BaseModel +from tabulate import tabulate import vision_agent.tools as T from vision_agent.agent import Agent from vision_agent.agent.agent_utils import ( + _MAX_TABULATE_COL_WIDTH, DefaultImports, extract_code, extract_json, @@ -91,6 +93,18 @@ def retrieve_tools( return tool_lists_unique +def _check_plan_format(plan: Dict[str, Any]) -> bool: + if not isinstance(plan, dict): + return False + + for k in plan: + if "thoughts" not in plan[k] or "instructions" not in plan[k]: + return False + if not isinstance(plan[k]["instructions"], list): + return False + return True + + def write_plans( chat: List[Message], tool_desc: str, working_memory: str, model: LMM ) -> Dict[str, Any]: @@ -106,7 +120,16 @@ def write_plans( feedback=working_memory, ) chat[-1]["content"] = prompt - return extract_json(model(chat, stream=False)) # type: ignore + plans = extract_json(model(chat, stream=False)) # type: ignore + + count = 0 + while not _check_plan_format(plans) and count < 3: + _LOGGER.info(f"Invalid plan format. Retrying.") + plans = extract_json(model(chat, stream=False)) # type: ignore + count += 1 + if count == 3: + raise ValueError("Failed to generate valid plans after 3 attempts.") + return plans def write_and_exec_plan_tests( @@ -404,6 +427,14 @@ def generate_plan( format_memory(working_memory), self.planner, ) + if self.verbosity >= 1: + for plan in plans: + plan_fixed = [ + {"instructions": e} for e in plans[plan]["instructions"] + ] + _LOGGER.info( + f"\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + ) tool_docs = retrieve_tools( plans,