Commit 8d3c60d

added long term memory

dillonalaird committed May 9, 2024
1 parent 35db97e commit 8d3c60d
Showing 2 changed files with 84 additions and 25 deletions.
76 changes: 61 additions & 15 deletions vision_agent/agent/vision_agent_v2.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import pandas as pd
 from rich.console import Console
 from rich.syntax import Syntax
 from tabulate import tabulate
@@ -20,6 +21,7 @@
     TEST,
     USER_REQ_CONTEXT,
     USER_REQ_SUBTASK_CONTEXT,
+    USER_REQ_SUBTASK_WM_CONTEXT,
 )
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.tools.tools_v2 import TOOL_DESCRIPTIONS, TOOLS_DF
@@ -48,11 +50,16 @@ def write_plan(
 
 
 def write_code(
-    user_req: str, subtask: str, tool_info: str, code: str, model: LLM
+    user_req: str,
+    subtask: str,
+    working_memory: str,
+    tool_info: str,
+    code: str,
+    model: LLM,
 ) -> str:
     prompt = CODE.format(
-        context=USER_REQ_SUBTASK_CONTEXT.format(
-            user_requirement=user_req, subtask=subtask
+        context=USER_REQ_SUBTASK_WM_CONTEXT.format(
+            user_requirement=user_req, working_memory=working_memory, subtask=subtask
         ),
         tool_info=tool_info,
         code=code,
Expand Down Expand Up @@ -83,14 +90,24 @@ def write_test(
return extract_code(code)


def debug_code(sub_task: str, working_memory: List[str], model: LLM) -> Tuple[str, str]:
def debug_code(
user_req: str,
subtask: str,
retrieved_ltm: str,
working_memory: str,
model: LLM,
) -> Tuple[str, str]:
# Make debug model output JSON
if hasattr(model, "kwargs"):
model.kwargs["response_format"] = {"type": "json_object"}
prompt = DEBUG.format(
debug_example=DEBUG_EXAMPLE,
context=USER_REQ_CONTEXT.format(user_requirement=sub_task),
previous_impl="\n".join(working_memory),
context=USER_REQ_SUBTASK_WM_CONTEXT.format(
user_requirement=user_req,
subtask=subtask,
working_memory=retrieved_ltm,
),
previous_impl=working_memory,
)
messages = [
{"role": "system", "content": DEBUG_SYS_MSG},
@@ -110,15 +127,17 @@ def write_and_exec_code(
     model: LLM,
     tool_info: str,
     exec: Execute,
+    retrieved_ltm: str,
     max_retry: int = 3,
     verbose: bool = False,
 ) -> Tuple[bool, str, str, Dict[str, List[str]]]:
     success = False
     counter = 0
     reflection = ""
 
-    # TODO: add working memory to code_writer_call and debug_code
-    code = code_writer_call(user_req, subtask, tool_info, orig_code, model)
+    code = code_writer_call(
+        user_req, subtask, retrieved_ltm, tool_info, orig_code, model
+    )
     success, result = exec.run_isolation(code)
     working_memory: Dict[str, List[str]] = {}
     while not success and counter < max_retry:
@@ -136,7 +155,9 @@
             PREV_CODE_CONTEXT.format(code=code, result=result)
         )
 
-        code, reflection = debug_code(subtask, working_memory[subtask], model)
+        code, reflection = debug_code(
+            user_req, subtask, retrieved_ltm, "\n".join(working_memory[subtask]), model
+        )
         success, result = exec.run_isolation(code)
         counter += 1
         if verbose:
@@ -148,7 +169,7 @@
     if success:
         working_memory[subtask].append(
             PREV_CODE_CONTEXT_WITH_REFLECTION.format(
-                code=code, result=result, reflection=reflection
+                reflection=reflection, code=code, result=result
             )
         )

@@ -162,12 +183,14 @@ def run_plan(
     exec: Execute,
     code: str,
     tool_recommender: Sim,
+    long_term_memory: Optional[Sim] = None,
     verbose: bool = False,
 ) -> Tuple[str, str, List[Dict[str, Any]], Dict[str, List[str]]]:
     active_plan = [e for e in plan if "success" not in e or not e["success"]]
-    working_memory: Dict[str, List[str]] = {}
     current_code = code
     current_test = ""
+    retrieved_ltm = ""
+    working_memory: Dict[str, List[str]] = {}
     for task in active_plan:
         _LOGGER.info(
             f"""
@@ -176,22 +199,29 @@
         tool_info = "\n".join(
             [e["doc"] for e in tool_recommender.top_k(task["instruction"])]
         )
-        success, code, result, task_memory = write_and_exec_code(
+
+        if long_term_memory is not None:
+            retrieved_ltm = "\n".join(
+                [e["doc"] for e in long_term_memory.top_k(task["instruction"], 1)]
+            )
+
+        success, code, result, working_memory_i = write_and_exec_code(
             user_req,
             task["instruction"],
             current_code,
             write_code if task["type"] == "code" else write_test,
             coder,
             tool_info,
             exec,
+            retrieved_ltm,
             verbose,
         )
         if task["type"] == "code":
             current_code = code
         else:
             current_test = code
 
-        working_memory.update(task_memory)
+        working_memory.update(working_memory_i)
 
         if verbose:
             _CONSOLE.print(
@@ -231,6 +261,7 @@ def __init__(
         self,
         timeout: int = 600,
         tool_recommender: Optional[Sim] = None,
+        long_term_memory: Optional[Sim] = None,
         verbose: bool = False,
     ) -> None:
         self.planner = OpenAILLM(temperature=0.1, json_mode=True)
@@ -241,6 +272,11 @@
         else:
             self.tool_recommender = tool_recommender
         self.verbose = verbose
+        self._working_memory: Dict[str, List[str]] = {}
+        if long_term_memory is not None:
+            if "doc" not in long_term_memory.df.columns:
+                raise ValueError("Long term memory must have a 'doc' column.")
+        self.long_term_memory = long_term_memory
         if self.verbose:
             _LOGGER.setLevel(logging.INFO)

@@ -271,12 +307,12 @@ def chat_with_tests(
             f"""Plan:
 {tabulate(tabular_data=plan, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
         )
-        working_memory: Dict[str, List[str]] = {}
 
         working_code = ""
         working_test = ""
         success = False
 
+        __import__("ipdb").set_trace()
         while not success:
             working_code, working_test, plan, working_memory_i = run_plan(
                 user_req,
@@ -285,16 +321,26 @@
                 self.exec,
                 working_code,
                 self.tool_recommender,
+                self.long_term_memory,
                 self.verbose,
             )
             success = all(task["success"] for task in plan)
-            working_memory.update(working_memory_i)
+            self._working_memory.update(working_memory_i)
 
             if not success:
                 # TODO: ask for feedback and replan
                 break
 
         return working_code, working_test
 
+    @property
+    def working_memory(self) -> Sim:
+        data: Dict[str, List[str]] = {"desc": [], "doc": []}
+        for key, value in self._working_memory.items():
+            data["desc"].append(key)
+            data["doc"].append("\n".join(value))
+        df = pd.DataFrame(data)
+        return Sim(df, sim_key="desc")
+
     def log_progress(self, description: str) -> None:
         pass
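Taken together, the vision_agent_v2.py changes thread a retrieved long-term memory string through write_code and debug_code, and expose the run's accumulated working memory as a Sim via the new property. A minimal usage sketch follows — the Sim import path and the chat_with_tests signature are assumptions, not shown in this diff:

```python
# Sketch only: assumes Sim wraps a DataFrame (as the working_memory property
# above suggests) and that chat_with_tests takes the request string directly.
import pandas as pd

from vision_agent.agent.vision_agent_v2 import VisionAgentV2
from vision_agent.utils.sim import Sim  # import path is an assumption

# Long-term memory is a Sim index over past experience; __init__ requires
# a 'doc' column (see the ValueError check above).
ltm_df = pd.DataFrame(
    {
        "desc": ["count objects in an image"],
        "doc": ["Used grounding_dino to find boxes, then len(boxes) to count."],
    }
)
agent = VisionAgentV2(long_term_memory=Sim(ltm_df, sim_key="desc"))

code, test = agent.chat_with_tests("Count the cars in street.jpg")  # signature assumed

# The accumulated working memory comes back as a Sim keyed on subtask
# descriptions, so it can serve as long_term_memory for a later agent.
next_ltm = agent.working_memory
```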
33 changes: 23 additions & 10 deletions vision_agent/agent/vision_agent_v2_prompt.py
@@ -1,3 +1,8 @@
+USER_REQ_CONTEXT = """
+## User Requirement
+{user_requirement}
+"""
+
 USER_REQ_SUBTASK_CONTEXT = """
 ## User Requirement
 {user_requirement}
@@ -6,11 +11,16 @@
 {subtask}
 """
 
-USER_REQ_CONTEXT = """
+USER_REQ_SUBTASK_WM_CONTEXT = """
 ## User Requirement
 {user_requirement}
-"""
+
+## Current Subtask
+{subtask}
+
+## Previous Task
+{working_memory}
+"""

PLAN = """
# Context
@@ -61,8 +71,9 @@
 {code}
 
 # Constraints
-- Write a function that accomplishes the User Requirement. You are supplied code from a previous task, feel free to copy over that code into your own implementation if you need it.
-- Always prioritize using pre-defined tools or code for the same functionality. You have access to all these tools through the `from vision_agent.tools.tools_v2 import *` import.
+- Write a function that accomplishes the 'User Requirement'. You are supplied code from a previous task under 'Previous Code', feel free to copy over that code into your own implementation if you need it.
+- Always prioritize using pre-defined tools or code for the same functionality from 'Tool Info for Current Subtask'. You have access to all these tools through the `from vision_agent.tools.tools_v2 import *` import.
+- You may receive previous trials and errors under 'Previous Task'; this is code, output, and reflections from previous tasks. You can use these to avoid running into the same issues when writing your code.
 - Write clean, readable, and well-documented code.
 
 # Output
@@ -102,6 +113,7 @@ def add(a: int, b: int) -> int:
 
 
 PREV_CODE_CONTEXT = """
+[previous impl]
 ```python
 {code}
 ```
@@ -112,18 +124,20 @@ def add(a: int, b: int) -> int:
 
 
 PREV_CODE_CONTEXT_WITH_REFLECTION = """
+[reflection on previous impl]
+{reflection}
+[new impl]
 ```python
 {code}
 ```
-[previous output]
+[new output]
 {result}
-[reflection on previous impl]
-{reflection}
 """


+# don't need [previous impl] because it will come from PREV_CODE_CONTEXT or PREV_CODE_CONTEXT_WITH_REFLECTION
 DEBUG = """
 [example]
 Here is an example of debugging with reflection.
@@ -133,7 +147,6 @@
 
 [context]
 {context}
 
-[previous impl]
 {previous_impl}
@@ -158,7 +171,7 @@
 {code}
 
 # Constraints
-- Write code to test the functionality of the provided code according to the Current Subtask. If you cannot test the code, then write code to visualize the result by calling the code.
+- Write code to test the functionality of the provided code according to the 'Current Subtask'. If you cannot test the code, then write code to visualize the result by calling the code.
 - Always prioritize using pre-defined tools for the same functionality.
 - Write clean, readable, and well-documented code.

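Nothing above pins long-term memory to a specific retriever: the agent only touches .df and top_k(query, k) items carrying a 'doc' key. Assuming duck typing is acceptable here (unverified), a keyword-overlap stand-in for offline experiments could look like:

```python
# Hypothetical stand-in for Sim — not part of this commit. It implements only
# the surface the diff relies on: a .df attribute and top_k() returning dicts
# that carry a "doc" key.
from typing import Any, Dict, List

import pandas as pd


class KeywordSim:
    def __init__(self, df: pd.DataFrame, sim_key: str = "desc") -> None:
        self.df = df  # VisionAgentV2.__init__ checks df.columns for "doc"
        self.sim_key = sim_key

    def top_k(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        # Crude word-overlap score in place of embedding similarity.
        words = set(query.lower().split())
        scored = self.df.assign(
            score=self.df[self.sim_key].apply(
                lambda desc: len(words & set(str(desc).lower().split()))
            )
        )
        return scored.nlargest(k, "score").drop(columns="score").to_dict("records")
```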