Skip to content

Commit

Permalink
added dynamic re-planning
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed May 10, 2024
1 parent 8d3c60d commit bf50f74
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 29 deletions.
88 changes: 59 additions & 29 deletions vision_agent/agent/vision_agent_v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import pandas as pd
from rich.console import Console
Expand Down Expand Up @@ -33,6 +33,15 @@
_CONSOLE = Console()


def build_working_memory(working_memory: Mapping[str, List[str]]) -> Sim:
data: Mapping[str, List[str]] = {"desc": [], "doc": []}
for key, value in working_memory.items():
data["desc"].append(key)
data["doc"].append("\n".join(value))
df = pd.DataFrame(data) # type: ignore
return Sim(df, sim_key="desc")


def extract_code(code: str) -> str:
if "```python" in code:
code = code[code.find("```python") + len("```python") :]
Expand All @@ -41,12 +50,21 @@ def extract_code(code: str) -> str:


def write_plan(
user_requirements: str, tool_desc: str, model: LLM
) -> List[Dict[str, Any]]:
chat: List[Dict[str, str]],
plan: Optional[List[Dict[str, Any]]],
tool_desc: str,
model: LLM,
) -> Tuple[str, List[Dict[str, Any]]]:
# Get last user request
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")
user_requirements = chat[-1]["content"]

context = USER_REQ_CONTEXT.format(user_requirement=user_requirements)
prompt = PLAN.format(context=context, plan="", tool_desc=tool_desc)
plan = json.loads(model(prompt).replace("```", "").strip())
return plan["plan"] # type: ignore
prompt = PLAN.format(context=context, plan=str(plan), tool_desc=tool_desc)
chat[-1]["content"] = prompt
plan = json.loads(model.chat(chat).replace("```", "").strip())
return plan["user_req"], plan["plan"] # type: ignore


def write_code(
Expand Down Expand Up @@ -123,7 +141,7 @@ def write_and_exec_code(
user_req: str,
subtask: str,
orig_code: str,
code_writer_call: Callable,
code_writer_call: Callable[..., str],
model: LLM,
tool_info: str,
exec: Execute,
Expand Down Expand Up @@ -191,6 +209,7 @@ def run_plan(
current_test = ""
retrieved_ltm = ""
working_memory: Dict[str, List[str]] = {}

for task in active_plan:
_LOGGER.info(
f"""
Expand All @@ -209,7 +228,7 @@ def run_plan(
user_req,
task["instruction"],
current_code,
write_code if task["type"] == "code" else write_test,
write_code if task["type"] == "code" else write_test, # type: ignore
coder,
tool_info,
exec,
Expand Down Expand Up @@ -277,43 +296,55 @@ def __init__(
if "doc" not in long_term_memory.df.columns:
raise ValueError("Long term memory must have a 'doc' column.")
self.long_term_memory = long_term_memory
self.max_retries = 3
if self.verbose:
_LOGGER.setLevel(logging.INFO)

def __call__(
self,
input: Union[List[Dict[str, str]], str],
image: Optional[Union[str, Path]] = None,
plan: Optional[List[Dict[str, Any]]] = None,
) -> str:
if isinstance(input, str):
input = [{"role": "user", "content": input}]
code, _ = self.chat_with_tests(input, image)
return code
results = self.chat_with_workflow(input, image, plan)
return results["code"] # type: ignore

def chat_with_tests(
def chat_with_workflow(
self,
chat: List[Dict[str, str]],
image: Optional[Union[str, Path]] = None,
) -> Tuple[str, str]:
plan: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
if len(chat) == 0:
raise ValueError("Input cannot be empty.")

user_req = chat[0]["content"]
if image is not None:
user_req += f" Image name {image}"
# append file names to all user messages
for chat_i in chat:
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {image}"

working_code = ""
if plan is not None:
# grab the latest working code from a previous plan
for task in plan:
if "success" in task and "code" in task and task["success"]:
working_code = task["code"]

plan = write_plan(user_req, TOOL_DESCRIPTIONS, self.planner)
user_req, plan = write_plan(chat, plan, TOOL_DESCRIPTIONS, self.planner)
_LOGGER.info(
f"""Plan:
{tabulate(tabular_data=plan, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)

working_code = ""
working_test = ""
working_memory: Dict[str, List[str]] = {}
success = False
retries = 0

__import__("ipdb").set_trace()
while not success:
while not success and retries < self.max_retries:
working_code, working_test, plan, working_memory_i = run_plan(
user_req,
plan,
Expand All @@ -325,22 +356,21 @@ def chat_with_tests(
self.verbose,
)
success = all(task["success"] for task in plan)
self._working_memory.update(working_memory_i)
working_memory.update(working_memory_i)

if not success:
# TODO: ask for feedback and replan
# return to user and request feedback
break

return working_code, working_test
retries += 1

@property
def working_memory(self) -> Sim:
data: Dict[str, List[str]] = {"desc": [], "doc": []}
for key, value in self._working_memory.items():
data["desc"].append(key)
data["doc"].append("\n".join(value))
df = pd.DataFrame(data)
return Sim(df, sim_key="desc")
return {
"code": working_code,
"test": working_test,
"success": success,
"working_memory": build_working_memory(working_memory),
"plan": plan,
}

def log_progress(self, description: str) -> None:
pass
2 changes: 2 additions & 0 deletions vision_agent/agent/vision_agent_v2_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@
- For each subtask, you should provide a short instruction on what to do. Ensure the subtasks are large enough to be meaningful, encompassing multiple lines of code.
- You do not need to have the agent rewrite any tool functionality you already have, you should instead instruct it to utilize one or more of those tools in each subtask.
- You can have agents either write coding tasks, to code some functionality or testing tasks to test previous functionality.
- If a current plan exists, examine each item in the plan to determine if it was successful. If there was an item that failed, i.e. 'success': False, then you should rewrite that item and all subsequent items to ensure that the rewritten plan is successful.
Output a list of jsons in the following format:
```json
{{
"user_req": str, # "a summarized version of the user requirement"
"plan":
[
{{
Expand Down

0 comments on commit bf50f74

Please sign in to comment.