diff --git a/vision_agent/agent/vision_agent_v2.py b/vision_agent/agent/vision_agent_v2.py index c2a875ce..dc5d9626 100644 --- a/vision_agent/agent/vision_agent_v2.py +++ b/vision_agent/agent/vision_agent_v2.py @@ -1,7 +1,7 @@ import json import logging from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union import pandas as pd from rich.console import Console @@ -33,6 +33,15 @@ _CONSOLE = Console() +def build_working_memory(working_memory: Mapping[str, List[str]]) -> Sim: + data: Mapping[str, List[str]] = {"desc": [], "doc": []} + for key, value in working_memory.items(): + data["desc"].append(key) + data["doc"].append("\n".join(value)) + df = pd.DataFrame(data) # type: ignore + return Sim(df, sim_key="desc") + + def extract_code(code: str) -> str: if "```python" in code: code = code[code.find("```python") + len("```python") :] @@ -41,12 +50,21 @@ def extract_code(code: str) -> str: def write_plan( - user_requirements: str, tool_desc: str, model: LLM -) -> List[Dict[str, Any]]: + chat: List[Dict[str, str]], + plan: Optional[List[Dict[str, Any]]], + tool_desc: str, + model: LLM, +) -> Tuple[str, List[Dict[str, Any]]]: + # Get last user request + if chat[-1]["role"] != "user": + raise ValueError("Last chat message must be from the user.") + user_requirements = chat[-1]["content"] + context = USER_REQ_CONTEXT.format(user_requirement=user_requirements) - prompt = PLAN.format(context=context, plan="", tool_desc=tool_desc) - plan = json.loads(model(prompt).replace("```", "").strip()) - return plan["plan"] # type: ignore + prompt = PLAN.format(context=context, plan=str(plan), tool_desc=tool_desc) + chat[-1]["content"] = prompt + plan = json.loads(model.chat(chat).replace("```", "").strip()) + return plan["user_req"], plan["plan"] # type: ignore def write_code( @@ -123,7 +141,7 @@ def write_and_exec_code( user_req: str, subtask: str, orig_code: str, - code_writer_call: Callable, + code_writer_call: Callable[..., str], model: LLM, tool_info: str, exec: Execute, @@ -191,6 +209,7 @@ def run_plan( current_test = "" retrieved_ltm = "" working_memory: Dict[str, List[str]] = {} + for task in active_plan: _LOGGER.info( f""" @@ -209,7 +228,7 @@ def run_plan( user_req, task["instruction"], current_code, - write_code if task["type"] == "code" else write_test, + write_code if task["type"] == "code" else write_test, # type: ignore coder, tool_info, exec, @@ -277,6 +296,7 @@ def __init__( if "doc" not in long_term_memory.df.columns: raise ValueError("Long term memory must have a 'doc' column.") self.long_term_memory = long_term_memory + self.max_retries = 3 if self.verbose: _LOGGER.setLevel(logging.INFO) @@ -284,36 +304,47 @@ def __call__( self, input: Union[List[Dict[str, str]], str], image: Optional[Union[str, Path]] = None, + plan: Optional[List[Dict[str, Any]]] = None, ) -> str: if isinstance(input, str): input = [{"role": "user", "content": input}] - code, _ = self.chat_with_tests(input, image) - return code + results = self.chat_with_workflow(input, image, plan) + return results["code"] # type: ignore - def chat_with_tests( + def chat_with_workflow( self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None, - ) -> Tuple[str, str]: + plan: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: if len(chat) == 0: raise ValueError("Input cannot be empty.") - user_req = chat[0]["content"] if image is not None: - user_req += f" Image name {image}" + # append file names to all user messages + for chat_i in chat: + if chat_i["role"] == "user": + chat_i["content"] += f" Image name {image}" + + working_code = "" + if plan is not None: + # grab the latest working code from a previous plan + for task in plan: + if "success" in task and "code" in task and task["success"]: + working_code = task["code"] - plan = write_plan(user_req, TOOL_DESCRIPTIONS, self.planner) + user_req, plan = write_plan(chat, plan, TOOL_DESCRIPTIONS, self.planner) _LOGGER.info( f"""Plan: {tabulate(tabular_data=plan, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}""" ) - working_code = "" working_test = "" + working_memory: Dict[str, List[str]] = {} success = False + retries = 0 - __import__("ipdb").set_trace() - while not success: + while not success and retries < self.max_retries: working_code, working_test, plan, working_memory_i = run_plan( user_req, plan, @@ -325,22 +356,21 @@ def chat_with_tests( self.verbose, ) success = all(task["success"] for task in plan) - self._working_memory.update(working_memory_i) + working_memory.update(working_memory_i) if not success: - # TODO: ask for feedback and replan + # return to user and request feedback break - return working_code, working_test + retries += 1 - @property - def working_memory(self) -> Sim: - data: Dict[str, List[str]] = {"desc": [], "doc": []} - for key, value in self._working_memory.items(): - data["desc"].append(key) - data["doc"].append("\n".join(value)) - df = pd.DataFrame(data) - return Sim(df, sim_key="desc") + return { + "code": working_code, + "test": working_test, + "success": success, + "working_memory": build_working_memory(working_memory), + "plan": plan, + } def log_progress(self, description: str) -> None: pass diff --git a/vision_agent/agent/vision_agent_v2_prompt.py b/vision_agent/agent/vision_agent_v2_prompt.py index 6965aa7b..4003b4df 100644 --- a/vision_agent/agent/vision_agent_v2_prompt.py +++ b/vision_agent/agent/vision_agent_v2_prompt.py @@ -37,11 +37,13 @@ - For each subtask, you should provide a short instruction on what to do. Ensure the subtasks are large enough to be meaningful, encompassing multiple lines of code. - You do not need to have the agent rewrite any tool functionality you already have, you should instead instruct it to utilize one or more of those tools in each subtask. - You can have agents either write coding tasks, to code some functionality or testing tasks to test previous functionality. +- If a current plan exists, examine each item in the plan to determine if it was successful. If there was an item that failed, i.e. 'success': False, then you should rewrite that item and all subsequent items to ensure that the rewritten plan is successful. Output a list of jsons in the following format: ```json {{ + "user_req": str, # "a summarized version of the user requirement" "plan": [ {{