Skip to content

Commit

Permalink
fix issue if plan format is incorrect
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 11, 2024
1 parent a740a28 commit 4e3cfaa
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
1 change: 1 addition & 0 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
_CONSOLE = Console()
_MAX_TABULATE_COL_WIDTH = 80


def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import vision_agent.tools as T
from vision_agent.agent.agent import Agent
from vision_agent.agent.agent_utils import (
_MAX_TABULATE_COL_WIDTH,
DefaultImports,
extract_code,
extract_json,
Expand Down Expand Up @@ -46,7 +47,6 @@
logging.basicConfig(stream=sys.stdout)
WORKSPACE = Path(os.getenv("WORKSPACE", ""))
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80


def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str:
Expand Down
33 changes: 32 additions & 1 deletion vision_agent/agent/vision_agent_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

from pydantic import BaseModel
from tabulate import tabulate

import vision_agent.tools as T
from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import (
_MAX_TABULATE_COL_WIDTH,
DefaultImports,
extract_code,
extract_json,
Expand Down Expand Up @@ -91,6 +93,18 @@ def retrieve_tools(
return tool_lists_unique


def _check_plan_format(plan: Dict[str, Any]) -> bool:
if not isinstance(plan, dict):
return False

for k in plan:
if "thoughts" not in plan[k] or "instructions" not in plan[k]:
return False
if not isinstance(plan[k]["instructions"], list):
return False
return True


def write_plans(
chat: List[Message], tool_desc: str, working_memory: str, model: LMM
) -> Dict[str, Any]:
Expand All @@ -106,7 +120,16 @@ def write_plans(
feedback=working_memory,
)
chat[-1]["content"] = prompt
return extract_json(model(chat, stream=False)) # type: ignore
plans = extract_json(model(chat, stream=False)) # type: ignore

count = 0
while not _check_plan_format(plans) and count < 3:
_LOGGER.info(f"Invalid plan format. Retrying.")
plans = extract_json(model(chat, stream=False)) # type: ignore
count += 1
if count == 3:
raise ValueError("Failed to generate valid plans after 3 attempts.")
return plans


def write_and_exec_plan_tests(
Expand Down Expand Up @@ -404,6 +427,14 @@ def generate_plan(
format_memory(working_memory),
self.planner,
)
if self.verbosity >= 1:
for plan in plans:
plan_fixed = [
{"instructions": e} for e in plans[plan]["instructions"]
]
_LOGGER.info(
f"\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)

tool_docs = retrieve_tools(
plans,
Expand Down

0 comments on commit 4e3cfaa

Please sign in to comment.