Skip to content

Commit 2c3fa2e

Browse files
committed
add AVA
1 parent 91dbacf commit 2c3fa2e

File tree

2 files changed

+435
-0
lines changed

2 files changed

+435
-0
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
import json
2+
import logging
3+
from pathlib import Path
4+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5+
6+
from rich.console import Console
7+
from rich.syntax import Syntax
8+
from tabulate import tabulate
9+
10+
from vision_agent.agent import Agent
11+
from vision_agent.llm import LLM, OpenAILLM
12+
from vision_agent.tools.tools_v2 import TOOL_DESCRIPTIONS, TOOLS_DF
13+
from vision_agent.utils import Execute, Sim
14+
15+
from .automated_vision_agent_prompt import (
16+
CODE,
17+
CODE_SYS_MSG,
18+
DEBUG,
19+
DEBUG_EXAMPLE,
20+
DEBUG_SYS_MSG,
21+
PLAN,
22+
PREV_CODE_CONTEXT,
23+
PREV_CODE_CONTEXT_WITH_REFLECTION,
24+
TEST,
25+
USER_REQ_CONTEXT,
26+
USER_REQ_SUBTASK_CONTEXT,
27+
)
28+
29+
logging.basicConfig(level=logging.INFO)
30+
_LOGGER = logging.getLogger(__name__)
31+
_MAX_TABULATE_COL_WIDTH = 80
32+
_CONSOLE = Console()
33+
34+
35+
def extract_code(code: str) -> str:
    """Extract Python source from a markdown-fenced LLM response.

    If *code* contains a ```python fence, return only the fenced content;
    otherwise return the input unchanged.
    """
    if "```python" in code:
        code = code[code.find("```python") + len("```python") :]
        # Guard against a missing closing fence: str.find returns -1 there,
        # and slicing with [:-1] would silently drop the last character.
        end = code.find("```")
        if end != -1:
            code = code[:end]
    return code
40+
41+
42+
def write_plan(
    user_requirements: str, tool_desc: str, model: LLM
) -> List[Dict[str, Any]]:
    """Ask *model* for a JSON plan that satisfies *user_requirements*.

    The model response is stripped of markdown fences, parsed as JSON, and
    the list of steps under its "plan" key is returned.
    """
    prompt = PLAN.format(
        context=USER_REQ_CONTEXT.format(user_requirement=user_requirements),
        plan="",
        tool_desc=tool_desc,
    )
    raw = model(prompt).replace("```", "").strip()
    return json.loads(raw)["plan"]
49+
50+
51+
def write_code(user_req: str, subtask: str, tool_info: str, code: str, model: LLM) -> str:
    """Have *model* write code for *subtask* and return the extracted source.

    *tool_info* documents the recommended tools; *code* is the implementation
    so far, which the prompt asks the model to extend.
    """
    subtask_context = USER_REQ_SUBTASK_CONTEXT.format(
        user_requirement=user_req, subtask=subtask
    )
    response = model.chat(
        [
            {"role": "system", "content": CODE_SYS_MSG},
            {
                "role": "user",
                "content": CODE.format(
                    context=subtask_context, tool_info=tool_info, code=code
                ),
            },
        ]
    )
    return extract_code(response)
63+
64+
65+
def write_test(user_req: str, subtask: str, tool_info: str, code: str, model: LLM) -> str:
    """Have *model* write a test for *subtask* and return the extracted source.

    Mirrors ``write_code`` but uses the TEST prompt so the model produces a
    test for *code* rather than an implementation.
    """
    subtask_context = USER_REQ_SUBTASK_CONTEXT.format(
        user_requirement=user_req, subtask=subtask
    )
    response = model.chat(
        [
            {"role": "system", "content": CODE_SYS_MSG},
            {
                "role": "user",
                "content": TEST.format(
                    context=subtask_context, tool_info=tool_info, code=code
                ),
            },
        ]
    )
    return extract_code(response)
77+
78+
79+
def debug_code(sub_task: str, working_memory: List[str], model: LLM) -> Tuple[str, str]:
    """Ask *model* to debug the failing implementations in *working_memory*.

    The model is temporarily switched to JSON output mode (when it exposes
    ``kwargs``) so the response can be parsed reliably.

    Returns:
        Tuple of (improved implementation code, reflection text).
    """
    _SENTINEL = object()
    has_kwargs = hasattr(model, "kwargs")
    prev_format: Any = _SENTINEL
    if has_kwargs:
        # Remember any caller-set response_format so it can be restored.
        prev_format = model.kwargs.get("response_format", _SENTINEL)
        model.kwargs["response_format"] = {"type": "json_object"}
    try:
        prompt = DEBUG.format(
            debug_example=DEBUG_EXAMPLE,
            context=USER_REQ_CONTEXT.format(user_requirement=sub_task),
            previous_impl="\n".join(working_memory),
        )
        messages = [
            {"role": "system", "content": DEBUG_SYS_MSG},
            {"role": "user", "content": prompt},
        ]
        code_and_ref = json.loads(model.chat(messages).replace("```", "").strip())
    finally:
        # BUG FIX: the original deleted the key only on the success path, so
        # an exception in model.chat/json.loads leaked JSON mode into every
        # subsequent (non-JSON) call; it also clobbered a pre-existing
        # caller-set response_format. Restore the prior state unconditionally.
        if has_kwargs:
            if prev_format is _SENTINEL:
                model.kwargs.pop("response_format", None)
            else:
                model.kwargs["response_format"] = prev_format
    return extract_code(code_and_ref["improved_impl"]), code_and_ref["reflection"]
96+
97+
98+
def write_and_exec_code(
    user_req: str,
    subtask: str,
    orig_code: str,
    code_writer_call: Callable,
    model: LLM,
    tool_info: str,
    exec: Execute,
    max_retry: int = 3,
    verbose: bool = False,
) -> Tuple[bool, str, str, Dict[str, List[str]]]:
    """Write code for *subtask*, execute it, and debug up to *max_retry* times.

    Returns:
        Tuple of (success flag, final code, last execution result, working
        memory of failed attempts keyed by subtask).
    """
    reflection = ""

    # TODO: add working memory to code_writer_call and debug_code
    code = code_writer_call(user_req, subtask, tool_info, orig_code, model)
    success, result = exec.run_isolation(code)
    working_memory: Dict[str, List[str]] = {}

    attempts = 0
    while not success and attempts < max_retry:
        working_memory.setdefault(subtask, [])

        # Record the failing attempt, with the model's reflection once one
        # exists (i.e. from the second debug round onward).
        if reflection:
            attempt_context = PREV_CODE_CONTEXT_WITH_REFLECTION.format(
                code=code, result=result, reflection=reflection
            )
        else:
            attempt_context = PREV_CODE_CONTEXT.format(code=code, result=result)
        working_memory[subtask].append(attempt_context)

        code, reflection = debug_code(subtask, working_memory[subtask], model)
        success, result = exec.run_isolation(code)
        attempts += 1
        if verbose:
            _CONSOLE.print(
                Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)
            )
        _LOGGER.info(f"\tDebugging reflection, result: {reflection}, {result}")

        if success:
            # Keep the successful repair in memory for later tasks.
            working_memory[subtask].append(
                PREV_CODE_CONTEXT_WITH_REFLECTION.format(
                    code=code, result=result, reflection=reflection
                )
            )

    return success, code, result, working_memory
149+
150+
151+
def run_plan(
    user_req: str,
    plan: List[Dict[str, Any]],
    coder: LLM,
    exec: Execute,
    code: str,
    tool_recommender: Sim,
    verbose: bool = False,
) -> Tuple[str, List[Dict[str, Any]]]:
    """Run every not-yet-successful task in *plan*.

    Each task dict is updated in place with its "success" flag, "result",
    and generated "code". Execution stops at the first task that still
    fails after retries.

    Returns:
        Tuple of (last generated code, the mutated plan).
    """
    active_plan = [e for e in plan if "success" not in e or not e["success"]]
    working_memory: Dict[str, List[str]] = {}
    for task in active_plan:
        _LOGGER.info(
            f"""
{tabulate(tabular_data=[task], headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
        )
        tool_info = "\n".join(
            [e["doc"] for e in tool_recommender.top_k(task["instruction"])]
        )
        # BUG FIX: `verbose` was previously passed positionally and landed in
        # write_and_exec_code's `max_retry` parameter (verbose=True silently
        # meant max_retry=1 with verbose left False). Pass it by keyword.
        success, code, result, task_memory = write_and_exec_code(
            user_req,
            task["instruction"],
            code,
            write_code if task["type"] == "code" else write_test,
            coder,
            tool_info,
            exec,
            verbose=verbose,
        )
        working_memory.update(task_memory)

        if verbose:
            _CONSOLE.print(
                Syntax(code, "python", theme="gruvbox-dark", line_numbers=True)
            )
        _LOGGER.info(f"\tCode success, result: {success}, {str(result)}")

        task["success"] = success
        task["result"] = result
        task["code"] = code

        if not success:
            break

    return code, plan
194+
195+
196+
class AutomatedVisionAgent(Agent):
    """Agent that plans, writes, tests, and debugs vision code end to end.

    A planner LLM (JSON mode) breaks the user request into subtasks; a coder
    LLM writes and debugs code for each subtask, with relevant tool
    documentation retrieved via *tool_recommender*.
    """

    def __init__(
        self,
        timeout: int = 600,
        tool_recommender: Optional[Sim] = None,
        verbose: bool = False,
    ) -> None:
        """Create the planner/coder LLMs and the isolated code executor.

        Args:
            timeout: per-execution timeout (seconds) for generated code.
            tool_recommender: similarity index over tool docs; defaults to
                one built from TOOLS_DF on the "desc" column.
            verbose: when True, also print generated code and set the module
                logger to INFO.
        """
        self.planner = OpenAILLM(temperature=0.1, json_mode=True)
        self.coder = OpenAILLM(temperature=0.1)
        self.exec = Execute(timeout=timeout)
        if tool_recommender is None:
            self.tool_recommender = Sim(TOOLS_DF, sim_key="desc")
        else:
            self.tool_recommender = tool_recommender
        # Never written elsewhere in this file; presumably reserved for
        # cross-call memory — TODO confirm intended element type.
        self.long_term_memory: List[str] = []
        self.verbose = verbose
        if self.verbose:
            _LOGGER.setLevel(logging.INFO)

    def __call__(
        self,
        input: Union[List[Dict[str, str]], str],
        image: Optional[Union[str, Path]] = None,
    ) -> str:
        """Accept a raw string or a chat-style message list and run `chat`."""
        if isinstance(input, str):
            input = [{"role": "user", "content": input}]
        return self.chat(input, image)

    def chat(
        self,
        chat: List[Dict[str, str]],
        image: Optional[Union[str, Path]] = None,
    ) -> str:
        """Plan subtasks for the first user message and execute them.

        Args:
            chat: chat history; only the first message's content is used.
            image: optional image path appended to the request text.

        Returns:
            The final working code produced by the plan.

        Raises:
            ValueError: if *chat* is empty.
        """
        if len(chat) == 0:
            raise ValueError("Input cannot be empty.")

        user_req = chat[0]["content"]
        if image is not None:
            user_req += f" Image name {image}"

        plan = write_plan(user_req, TOOL_DESCRIPTIONS, self.planner)
        _LOGGER.info(
            f"""Plan:
{tabulate(tabular_data=plan, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"""
        )

        working_code = ""
        success = False

        # NOTE(review): this loops until every task succeeds; run_plan retries
        # each failing task a bounded number of times, but a task that never
        # succeeds makes this loop spin forever — consider a pass limit.
        while not success:
            working_code, plan = run_plan(
                user_req,
                plan,
                self.coder,
                self.exec,
                working_code,
                self.tool_recommender,
                self.verbose,
            )
            success = all(task["success"] for task in plan)

            if not success:
                # TODO: replan or carry failure context into the next pass.
                pass

        return working_code

    def log_progress(self, description: str) -> None:
        """No-op hook required by the Agent interface."""
        pass

0 commit comments

Comments
 (0)