Add Long Term Memory and Feedback (#80)
* fixed save and load

* added long term memory

* added dynamic re-planning

* add gpt-4o

* update tests

* update tests

* fixed exiting the loop early

* add some extra parsing for code snippets

* fix formatting

* fix typing error
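
The pieces above (long term memory, dynamic re-planning, and the switch from chat_with_tests to chat_with_workflow) fit together roughly as in the sketch below. It is a minimal illustration, not code from the repository: the agent class name (VisionAgentV2 here), both import paths, and the memory contents are assumptions; only the constructor keywords, the chat_with_workflow signature, and the returned dictionary keys come from the diff further down.

import pandas as pd

from vision_agent.agent.vision_agent_v2 import VisionAgentV2  # assumed class name and path
from vision_agent.utils.sim import Sim  # assumed import path for Sim

# Long term memory is a Sim index; the new __init__ check in the diff requires
# it to expose a "doc" column.
ltm = Sim(
    pd.DataFrame(
        {
            "desc": ["count objects in an image"],
            "doc": ["Detect objects with an object detector, then count the boxes."],
        }
    ),
    sim_key="desc",
)

agent = VisionAgentV2(long_term_memory=ltm, verbose=True)

# chat_with_workflow replaces chat_with_tests and returns a dict rather than a
# (code, test) tuple.
results = agent.chat_with_workflow(
    [{"role": "user", "content": "Count the cars"}],
    image="cars.jpg",
)

if not results["success"]:
    # Dynamic re-planning: hand the previous plan back so only subtasks that
    # are not yet marked successful are re-run.
    results = agent.chat_with_workflow(
        [{"role": "user", "content": "Count only the red cars"}],
        image="cars.jpg",
        plan=results["plan"],
    )

print(results["code"], results["test"], sep="\n")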
dillonalaird authored May 13, 2024
1 parent d0d83f8 commit d6fd63e
Showing 6 changed files with 160 additions and 52 deletions.
8 changes: 4 additions & 4 deletions tests/test_llm.py
@@ -18,7 +18,7 @@ def test_generate_with_mock(openai_llm_mock): # noqa: F811
response = llm.generate("test prompt")
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo",
model="gpt-4o",
messages=[{"role": "user", "content": "test prompt"}],
)

@@ -31,7 +31,7 @@ def test_chat_with_mock(openai_llm_mock): # noqa: F811
response = llm.chat([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo",
model="gpt-4o",
messages=[{"role": "user", "content": "test prompt"}],
)

@@ -44,14 +44,14 @@ def test_call_with_mock(openai_llm_mock): # noqa: F811
response = llm("test prompt")
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo",
model="gpt-4o",
messages=[{"role": "user", "content": "test prompt"}],
)

response = llm([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_with(
model="gpt-4-turbo",
model="gpt-4o",
messages=[{"role": "user", "content": "test prompt"}],
)

146 changes: 112 additions & 34 deletions vision_agent/agent/vision_agent_v2.py
@@ -1,8 +1,9 @@
import json
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import pandas as pd
from rich.console import Console
from rich.syntax import Syntax
from tabulate import tabulate
@@ -20,6 +21,7 @@
TEST,
USER_REQ_CONTEXT,
USER_REQ_SUBTASK_CONTEXT,
USER_REQ_SUBTASK_WM_CONTEXT,
)
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.tools.tools_v2 import TOOL_DESCRIPTIONS, TOOLS_DF
@@ -31,28 +33,53 @@
_CONSOLE = Console()


def build_working_memory(working_memory: Mapping[str, List[str]]) -> Sim:
data: Mapping[str, List[str]] = {"desc": [], "doc": []}
for key, value in working_memory.items():
data["desc"].append(key)
data["doc"].append("\n".join(value))
df = pd.DataFrame(data) # type: ignore
return Sim(df, sim_key="desc")


def extract_code(code: str) -> str:
if "```python" in code:
code = code[code.find("```python") + len("```python") :]
code = code[: code.find("```")]
if code.startswith("python\n"):
code = code[len("python\n") :]
return code


def write_plan(
user_requirements: str, tool_desc: str, model: LLM
) -> List[Dict[str, Any]]:
chat: List[Dict[str, str]],
plan: Optional[List[Dict[str, Any]]],
tool_desc: str,
model: LLM,
) -> Tuple[str, List[Dict[str, Any]]]:
# Get last user request
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")
user_requirements = chat[-1]["content"]

context = USER_REQ_CONTEXT.format(user_requirement=user_requirements)
prompt = PLAN.format(context=context, plan="", tool_desc=tool_desc)
plan = json.loads(model(prompt).replace("```", "").strip())
return plan["plan"] # type: ignore
prompt = PLAN.format(context=context, plan=str(plan), tool_desc=tool_desc)
chat[-1]["content"] = prompt
plan = json.loads(model.chat(chat).replace("```", "").strip())
return plan["user_req"], plan["plan"] # type: ignore


def write_code(
user_req: str, subtask: str, tool_info: str, code: str, model: LLM
user_req: str,
subtask: str,
working_memory: str,
tool_info: str,
code: str,
model: LLM,
) -> str:
prompt = CODE.format(
context=USER_REQ_SUBTASK_CONTEXT.format(
user_requirement=user_req, subtask=subtask
context=USER_REQ_SUBTASK_WM_CONTEXT.format(
user_requirement=user_req, working_memory=working_memory, subtask=subtask
),
tool_info=tool_info,
code=code,
@@ -66,7 +93,7 @@ def write_code(


def write_test(
user_req: str, subtask: str, tool_info: str, code: str, model: LLM
user_req: str, subtask: str, tool_info: str, _: str, code: str, model: LLM
) -> str:
prompt = TEST.format(
context=USER_REQ_SUBTASK_CONTEXT.format(
@@ -83,14 +110,24 @@ def write_test(
return extract_code(code)


def debug_code(sub_task: str, working_memory: List[str], model: LLM) -> Tuple[str, str]:
def debug_code(
user_req: str,
subtask: str,
retrieved_ltm: str,
working_memory: str,
model: LLM,
) -> Tuple[str, str]:
# Make debug model output JSON
if hasattr(model, "kwargs"):
model.kwargs["response_format"] = {"type": "json_object"}
prompt = DEBUG.format(
debug_example=DEBUG_EXAMPLE,
context=USER_REQ_CONTEXT.format(user_requirement=sub_task),
previous_impl="\n".join(working_memory),
context=USER_REQ_SUBTASK_WM_CONTEXT.format(
user_requirement=user_req,
subtask=subtask,
working_memory=retrieved_ltm,
),
previous_impl=working_memory,
)
messages = [
{"role": "system", "content": DEBUG_SYS_MSG},
@@ -106,19 +143,21 @@ def write_and_exec_code(
user_req: str,
subtask: str,
orig_code: str,
code_writer_call: Callable,
code_writer_call: Callable[..., str],
model: LLM,
tool_info: str,
exec: Execute,
retrieved_ltm: str,
max_retry: int = 3,
verbose: bool = False,
) -> Tuple[bool, str, str, Dict[str, List[str]]]:
success = False
counter = 0
reflection = ""

# TODO: add working memory to code_writer_call and debug_code
code = code_writer_call(user_req, subtask, tool_info, orig_code, model)
code = code_writer_call(
user_req, subtask, retrieved_ltm, tool_info, orig_code, model
)
success, result = exec.run_isolation(code)
working_memory: Dict[str, List[str]] = {}
while not success and counter < max_retry:
@@ -136,7 +175,9 @@ def write_and_exec_code(
PREV_CODE_CONTEXT.format(code=code, result=result)
)

code, reflection = debug_code(subtask, working_memory[subtask], model)
code, reflection = debug_code(
user_req, subtask, retrieved_ltm, "\n".join(working_memory[subtask]), model
)
success, result = exec.run_isolation(code)
counter += 1
if verbose:
@@ -148,7 +189,7 @@ def write_and_exec_code(
if success:
working_memory[subtask].append(
PREV_CODE_CONTEXT_WITH_REFLECTION.format(
code=code, result=result, reflection=reflection
reflection=reflection, code=code, result=result
)
)

@@ -162,12 +203,15 @@ def run_plan(
exec: Execute,
code: str,
tool_recommender: Sim,
long_term_memory: Optional[Sim] = None,
verbose: bool = False,
) -> Tuple[str, str, List[Dict[str, Any]], Dict[str, List[str]]]:
active_plan = [e for e in plan if "success" not in e or not e["success"]]
working_memory: Dict[str, List[str]] = {}
current_code = code
current_test = ""
retrieved_ltm = ""
working_memory: Dict[str, List[str]] = {}

for task in active_plan:
_LOGGER.info(
f"""
@@ -176,22 +220,29 @@
tool_info = "\n".join(
[e["doc"] for e in tool_recommender.top_k(task["instruction"])]
)
success, code, result, task_memory = write_and_exec_code(

if long_term_memory is not None:
retrieved_ltm = "\n".join(
[e["doc"] for e in long_term_memory.top_k(task["instruction"], 1)]
)

success, code, result, working_memory_i = write_and_exec_code(
user_req,
task["instruction"],
current_code,
write_code if task["type"] == "code" else write_test,
coder,
tool_info,
exec,
verbose,
retrieved_ltm,
verbose=verbose,
)
if task["type"] == "code":
current_code = code
else:
current_test = code

working_memory.update(task_memory)
working_memory.update(working_memory_i)

if verbose:
_CONSOLE.print(
@@ -231,6 +282,7 @@ def __init__(
self,
timeout: int = 600,
tool_recommender: Optional[Sim] = None,
long_term_memory: Optional[Sim] = None,
verbose: bool = False,
) -> None:
self.planner = OpenAILLM(temperature=0.1, json_mode=True)
@@ -241,60 +293,86 @@ def __init__(
else:
self.tool_recommender = tool_recommender
self.verbose = verbose
self._working_memory: Dict[str, List[str]] = {}
if long_term_memory is not None:
if "doc" not in long_term_memory.df.columns:
raise ValueError("Long term memory must have a 'doc' column.")
self.long_term_memory = long_term_memory
self.max_retries = 3
if self.verbose:
_LOGGER.setLevel(logging.INFO)

def __call__(
self,
input: Union[List[Dict[str, str]], str],
image: Optional[Union[str, Path]] = None,
plan: Optional[List[Dict[str, Any]]] = None,
) -> str:
if isinstance(input, str):
input = [{"role": "user", "content": input}]
code, _ = self.chat_with_tests(input, image)
return code
results = self.chat_with_workflow(input, image, plan)
return results["code"] # type: ignore

def chat_with_tests(
def chat_with_workflow(
self,
chat: List[Dict[str, str]],
image: Optional[Union[str, Path]] = None,
) -> Tuple[str, str]:
plan: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
if len(chat) == 0:
raise ValueError("Input cannot be empty.")

user_req = chat[0]["content"]
if image is not None:
user_req += f" Image name {image}"
# append file names to all user messages
for chat_i in chat:
if chat_i["role"] == "user":
chat_i["content"] += f" Image name {image}"

working_code = ""
if plan is not None:
# grab the latest working code from a previous plan
for task in plan:
if "success" in task and "code" in task and task["success"]:
working_code = task["code"]

plan = write_plan(user_req, TOOL_DESCRIPTIONS, self.planner)
user_req, plan = write_plan(chat, plan, TOOL_DESCRIPTIONS, self.planner)
_LOGGER.info(
f"""Plan:
{tabulate(tabular_data=plan, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
)
working_memory: Dict[str, List[str]] = {}

working_code = ""
working_test = ""
working_memory: Dict[str, List[str]] = {}
success = False
retries = 0

while not success:
while not success and retries < self.max_retries:
working_code, working_test, plan, working_memory_i = run_plan(
user_req,
plan,
self.coder,
self.exec,
working_code,
self.tool_recommender,
self.long_term_memory,
self.verbose,
)
success = all(task["success"] for task in plan)
working_memory.update(working_memory_i)

if not success:
# TODO: ask for feedback and replan
# return to user and request feedback
break

return working_code, working_test
retries += 1

return {
"code": working_code,
"test": working_test,
"success": success,
"working_memory": build_working_memory(working_memory),
"plan": plan,
}

def log_progress(self, description: str) -> None:
pass
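
The feedback half of the change is the working_memory entry in chat_with_workflow's return value: each subtask accumulates its failed attempts and reflections, and build_working_memory flattens them into a Sim index. The pandas-only sketch below mirrors that conversion; the example instruction and notes are illustrative, and the Sim wrapper is omitted because its module is not part of this diff. Since the resulting frame has a "doc" column, it appears the returned index can be passed back in as long_term_memory on a later call.

import pandas as pd

# Working memory is accumulated per subtask as {instruction: [attempt notes, ...]}.
working_memory = {
    "Detect the cars in the image": [
        "[previous implementation and result]",
        "[reflection, revised implementation and result]",
    ],
}

# Mirrors build_working_memory: one row per subtask, "desc" holds the
# instruction and "doc" the concatenated attempt history.
data = {"desc": [], "doc": []}
for key, value in working_memory.items():
    data["desc"].append(key)
    data["doc"].append("\n".join(value))

df = pd.DataFrame(data)
print(df)
# In the real code this frame is wrapped as Sim(df, sim_key="desc"), so lookups
# run against the subtask description and the stored notes come back via "doc".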
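The "add some extra parsing for code snippets" bullet in the commit message appears to refer to extract_code in the diff above, which now also strips a bare leading "python" language tag in addition to fenced blocks. Below is a self-contained copy of the helper with a few illustrative inputs (the inputs are made up, not taken from the tests):

def extract_code(code: str) -> str:
    # Keep only the body of a fenced ```python block, if one is present.
    if "```python" in code:
        code = code[code.find("```python") + len("```python") :]
        code = code[: code.find("```")]
    # New in this commit: also strip a bare leading "python" language tag.
    if code.startswith("python\n"):
        code = code[len("python\n") :]
    return code


print(repr(extract_code("```python\nprint(1)\n```")))  # '\nprint(1)\n'
print(repr(extract_code("python\nprint(1)")))          # 'print(1)'
print(repr(extract_code("print(1)")))                  # unchanged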
