Skip to content

Commit bf50f74

Browse files
committed
added dynamic re-planning
1 parent 8d3c60d commit bf50f74

File tree

2 files changed: 61 additions and 29 deletions.

2 files changed: 61 additions and 29 deletions.

vision_agent/agent/vision_agent_v2.py

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
from pathlib import Path
4-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
55

66
import pandas as pd
77
from rich.console import Console
@@ -33,6 +33,15 @@
3333
_CONSOLE = Console()
3434

3535

36+
def build_working_memory(working_memory: Mapping[str, List[str]]) -> Sim:
37+
data: Mapping[str, List[str]] = {"desc": [], "doc": []}
38+
for key, value in working_memory.items():
39+
data["desc"].append(key)
40+
data["doc"].append("\n".join(value))
41+
df = pd.DataFrame(data) # type: ignore
42+
return Sim(df, sim_key="desc")
43+
44+
3645
def extract_code(code: str) -> str:
3746
if "```python" in code:
3847
code = code[code.find("```python") + len("```python") :]
@@ -41,12 +50,21 @@ def extract_code(code: str) -> str:
4150

4251

4352
def write_plan(
44-
user_requirements: str, tool_desc: str, model: LLM
45-
) -> List[Dict[str, Any]]:
53+
chat: List[Dict[str, str]],
54+
plan: Optional[List[Dict[str, Any]]],
55+
tool_desc: str,
56+
model: LLM,
57+
) -> Tuple[str, List[Dict[str, Any]]]:
58+
# Get last user request
59+
if chat[-1]["role"] != "user":
60+
raise ValueError("Last chat message must be from the user.")
61+
user_requirements = chat[-1]["content"]
62+
4663
context = USER_REQ_CONTEXT.format(user_requirement=user_requirements)
47-
prompt = PLAN.format(context=context, plan="", tool_desc=tool_desc)
48-
plan = json.loads(model(prompt).replace("```", "").strip())
49-
return plan["plan"] # type: ignore
64+
prompt = PLAN.format(context=context, plan=str(plan), tool_desc=tool_desc)
65+
chat[-1]["content"] = prompt
66+
plan = json.loads(model.chat(chat).replace("```", "").strip())
67+
return plan["user_req"], plan["plan"] # type: ignore
5068

5169

5270
def write_code(
@@ -123,7 +141,7 @@ def write_and_exec_code(
123141
user_req: str,
124142
subtask: str,
125143
orig_code: str,
126-
code_writer_call: Callable,
144+
code_writer_call: Callable[..., str],
127145
model: LLM,
128146
tool_info: str,
129147
exec: Execute,
@@ -191,6 +209,7 @@ def run_plan(
191209
current_test = ""
192210
retrieved_ltm = ""
193211
working_memory: Dict[str, List[str]] = {}
212+
194213
for task in active_plan:
195214
_LOGGER.info(
196215
f"""
@@ -209,7 +228,7 @@ def run_plan(
209228
user_req,
210229
task["instruction"],
211230
current_code,
212-
write_code if task["type"] == "code" else write_test,
231+
write_code if task["type"] == "code" else write_test, # type: ignore
213232
coder,
214233
tool_info,
215234
exec,
@@ -277,43 +296,55 @@ def __init__(
277296
if "doc" not in long_term_memory.df.columns:
278297
raise ValueError("Long term memory must have a 'doc' column.")
279298
self.long_term_memory = long_term_memory
299+
self.max_retries = 3
280300
if self.verbose:
281301
_LOGGER.setLevel(logging.INFO)
282302

283303
def __call__(
284304
self,
285305
input: Union[List[Dict[str, str]], str],
286306
image: Optional[Union[str, Path]] = None,
307+
plan: Optional[List[Dict[str, Any]]] = None,
287308
) -> str:
288309
if isinstance(input, str):
289310
input = [{"role": "user", "content": input}]
290-
code, _ = self.chat_with_tests(input, image)
291-
return code
311+
results = self.chat_with_workflow(input, image, plan)
312+
return results["code"] # type: ignore
292313

293-
def chat_with_tests(
314+
def chat_with_workflow(
294315
self,
295316
chat: List[Dict[str, str]],
296317
image: Optional[Union[str, Path]] = None,
297-
) -> Tuple[str, str]:
318+
plan: Optional[List[Dict[str, Any]]] = None,
319+
) -> Dict[str, Any]:
298320
if len(chat) == 0:
299321
raise ValueError("Input cannot be empty.")
300322

301-
user_req = chat[0]["content"]
302323
if image is not None:
303-
user_req += f" Image name {image}"
324+
# append file names to all user messages
325+
for chat_i in chat:
326+
if chat_i["role"] == "user":
327+
chat_i["content"] += f" Image name {image}"
328+
329+
working_code = ""
330+
if plan is not None:
331+
# grab the latest working code from a previous plan
332+
for task in plan:
333+
if "success" in task and "code" in task and task["success"]:
334+
working_code = task["code"]
304335

305-
plan = write_plan(user_req, TOOL_DESCRIPTIONS, self.planner)
336+
user_req, plan = write_plan(chat, plan, TOOL_DESCRIPTIONS, self.planner)
306337
_LOGGER.info(
307338
f"""Plan:
308339
{tabulate(tabular_data=plan, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
309340
)
310341

311-
working_code = ""
312342
working_test = ""
343+
working_memory: Dict[str, List[str]] = {}
313344
success = False
345+
retries = 0
314346

315-
__import__("ipdb").set_trace()
316-
while not success:
347+
while not success and retries < self.max_retries:
317348
working_code, working_test, plan, working_memory_i = run_plan(
318349
user_req,
319350
plan,
@@ -325,22 +356,21 @@ def chat_with_tests(
325356
self.verbose,
326357
)
327358
success = all(task["success"] for task in plan)
328-
self._working_memory.update(working_memory_i)
359+
working_memory.update(working_memory_i)
329360

330361
if not success:
331-
# TODO: ask for feedback and replan
362+
# return to user and request feedback
332363
break
333364

334-
return working_code, working_test
365+
retries += 1
335366

336-
@property
337-
def working_memory(self) -> Sim:
338-
data: Dict[str, List[str]] = {"desc": [], "doc": []}
339-
for key, value in self._working_memory.items():
340-
data["desc"].append(key)
341-
data["doc"].append("\n".join(value))
342-
df = pd.DataFrame(data)
343-
return Sim(df, sim_key="desc")
367+
return {
368+
"code": working_code,
369+
"test": working_test,
370+
"success": success,
371+
"working_memory": build_working_memory(working_memory),
372+
"plan": plan,
373+
}
344374

345375
def log_progress(self, description: str) -> None:
346376
pass

vision_agent/agent/vision_agent_v2_prompt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@
3737
- For each subtask, you should provide a short instruction on what to do. Ensure the subtasks are large enough to be meaningful, encompassing multiple lines of code.
3838
- You do not need to have the agent rewrite any tool functionality you already have, you should instead instruct it to utilize one or more of those tools in each subtask.
3939
- You can have agents either write coding tasks, to code some functionality or testing tasks to test previous functionality.
40+
- If a current plan exists, examine each item in the plan to determine if it was successful. If there was an item that failed, i.e. 'success': False, then you should rewrite that item and all subsequent items to ensure that the rewritten plan is successful.
4041
4142
Output a list of jsons in the following format:
4243
4344
```json
4445
{{
46+
"user_req": str, # "a summarized version of the user requirement"
4547
"plan":
4648
[
4749
{{

0 commit comments

Comments
 (0)