Commit ca1dc57

Separate out planner (#256)
* separated out planner, renamed chat methods
* fixed circular imports
* added type for plan context
* add planner as separate call to vision agent
* export plan context
* fixed circular imports
* fixed wrong key
* better json parsing
* more test cases for json parsing
* have planner visualize results
* add more guard rails to remove double chat
* revert changes with planning step for now
* revert to original prompts
* fix type issue
* fix format issue
* skip examples for flake8
* fix names and readme
* fixed type error
* fix countgd integ test
* synced code with new code interpreter arg
1 parent 296af2d commit ca1dc57

15 files changed: +1174 -740 lines changed

.github/workflows/ci_cd.yml

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ jobs:
 - name: Linting
   run: |
     # stop the build if there are Python syntax errors or undefined names
-    poetry run flake8 . --exclude .venv --count --show-source --statistics
+    poetry run flake8 . --exclude .venv,examples --count --show-source --statistics
 - name: Check Format
   run: |
     poetry run black --check --diff --color .

README.md

Lines changed: 6 additions & 6 deletions
@@ -101,7 +101,7 @@ continuing, for example it may want to execute code and look at the output befor
 letting the user respond.
 
 ### Chatting and Artifacts
-If you run `chat_with_code` you will also notice an `Artifact` object. `Artifact`'s
+If you run `chat_with_artifacts` you will also notice an `Artifact` object. `Artifact`'s
 are a way to sync files between local and remote environments. The agent will read and
 write to the artifact object, which is just a pickle object, when it wants to save or
 load files.
@@ -118,7 +118,7 @@ with open("image.png", "rb") as f:
     artifacts["image.png"] = f.read()
 
 agent = va.agent.VisionAgent()
-response, artifacts = agent.chat_with_code(
+response, artifacts = agent.chat_with_artifacts(
     [
         {
             "role": "user",
@@ -298,11 +298,11 @@ mode by passing in the verbose argument:
 ```
 
 ### Detailed Usage
-You can also have it return more information by calling `chat_with_workflow`. The format
+You can also have it return more information by calling `generate_code`. The format
 of the input is a list of dictionaries with the keys `role`, `content`, and `media`:
 
 ```python
->>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
+>>> results = agent.generate_code([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
 >>> print(results)
 {
     "code": "from vision_agent.tools import ..."
@@ -331,7 +331,7 @@ conv = [
         "media": ["workers.png"],
     }
 ]
-result = agent.chat_with_workflow(conv)
+result = agent.generate_code(conv)
 code = result["code"]
 conv.append({"role": "assistant", "content": code})
 conv.append(
@@ -340,7 +340,7 @@ conv.append(
         "content": "Can you also return the number of workers wearing safety gear?",
     }
 )
-result = agent.chat_with_workflow(conv)
+result = agent.generate_code(conv)
 ```
 
 

docs/index.md

Lines changed: 6 additions & 6 deletions
@@ -97,7 +97,7 @@ continuing, for example it may want to execute code and look at the output befor
 letting the user respond.
 
 ### Chatting and Artifacts
-If you run `chat_with_code` you will also notice an `Artifact` object. `Artifact`'s
+If you run `chat_with_artifacts` you will also notice an `Artifact` object. `Artifact`'s
 are a way to sync files between local and remote environments. The agent will read and
 write to the artifact object, which is just a pickle object, when it wants to save or
 load files.
@@ -114,7 +114,7 @@ with open("image.png", "rb") as f:
     artifacts["image.png"] = f.read()
 
 agent = va.agent.VisionAgent()
-response, artifacts = agent.chat_with_code(
+response, artifacts = agent.chat_with_artifacts(
     [
         {
             "role": "user",
@@ -294,11 +294,11 @@ mode by passing in the verbose argument:
 ```
 
 ### Detailed Usage
-You can also have it return more information by calling `chat_with_workflow`. The format
+You can also have it return more information by calling `generate_code`. The format
 of the input is a list of dictionaries with the keys `role`, `content`, and `media`:
 
 ```python
->>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
+>>> results = agent.generate_code([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
 >>> print(results)
 {
     "code": "from vision_agent.tools import ..."
@@ -327,7 +327,7 @@ conv = [
         "media": ["workers.png"],
     }
 ]
-result = agent.chat_with_workflow(conv)
+result = agent.generate_code(conv)
 code = result["code"]
 conv.append({"role": "assistant", "content": code})
 conv.append(
@@ -336,7 +336,7 @@ conv.append(
         "content": "Can you also return the number of workers wearing safety gear?",
     }
 )
-result = agent.chat_with_workflow(conv)
+result = agent.generate_code(conv)
 ```
 
 
342342

examples/chat/app.py

Lines changed: 21 additions & 11 deletions
@@ -27,7 +27,7 @@
     "style": {"bottom": "calc(50% - 4.25rem", "right": "0.4rem"},
 }
 # set artifacts remote_path to WORKSPACE
-artifacts = va.tools.Artifacts(WORKSPACE / "artifacts.pkl")
+artifacts = va.tools.meta_tools.Artifacts(WORKSPACE / "artifacts.pkl")
 if Path("artifacts.pkl").exists():
     artifacts.load("artifacts.pkl")
 else:
@@ -109,16 +109,26 @@ def main():
                     len(st.session_state.messages) == 0
                     or prompt != st.session_state.messages[-1]["content"]
                 ):
-                    st.session_state.messages.append(
-                        {"role": "user", "content": prompt}
-                    )
-                    messages.chat_message("user").write(prompt)
-                    message_thread = threading.Thread(
-                        target=update_messages,
-                        args=(st.session_state.messages, message_lock),
-                    )
-                    message_thread.daemon = True
-                    message_thread.start()
+                    # occassionally resends the last user message twice
+                    user_messages = [
+                        msg
+                        for msg in st.session_state.messages
+                        if msg["role"] == "user"
+                    ]
+                    last_user_message = None
+                    if len(user_messages) > 0:
+                        last_user_message = user_messages[-1]["content"]
+                    if last_user_message is None or last_user_message != prompt:
+                        st.session_state.messages.append(
+                            {"role": "user", "content": prompt}
+                        )
+                        messages.chat_message("user").write(prompt)
+                        message_thread = threading.Thread(
+                            target=update_messages,
+                            args=(st.session_state.messages, message_lock),
+                        )
+                        message_thread.daemon = True
+                        message_thread.start()
                 st.session_state.input_text = ""
 
     with tabs[1]:
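
The guard added in this hunk compares the incoming prompt against the last message whose role is `user`, whereas the surrounding condition compares it against the last message of any role. Once an assistant reply has been appended, the outer check no longer catches a resend, which appears to be the case the new lines handle. A minimal sketch of that check as a standalone function, with no Streamlit dependency (the name `should_append_prompt` and the sample history are hypothetical, not part of the commit):

```python
from typing import Dict, List

Message = Dict[str, str]


def should_append_prompt(messages: List[Message], prompt: str) -> bool:
    """Return True unless `prompt` repeats the most recent user message."""
    # Only user messages count; assistant replies in between are ignored.
    user_messages = [msg for msg in messages if msg["role"] == "user"]
    if not user_messages:
        return True
    return user_messages[-1]["content"] != prompt


history = [
    {"role": "user", "content": "count the workers"},
    {"role": "assistant", "content": "Working on it."},
]

print(should_append_prompt(history, "count the workers"))  # False: a resend
print(should_append_prompt(history, "now count helmets"))  # True: a new prompt
print(should_append_prompt([], "hello"))                   # True: empty history
```

One consequence of this design is that a user who deliberately sends the same text twice in a row is also filtered out.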

tests/integ/test_tools.py

Lines changed: 1 addition & 1 deletion
@@ -413,4 +413,4 @@ def test_countgd_example_based_counting() -> None:
         image=img,
     )
     assert len(result) == 24
-    assert [res["label"] for res in result] == ["coin"] * 24
+    assert [res["label"] for res in result] == ["object"] * 24

tests/unit/test_utils.py

Lines changed: 7 additions & 0 deletions
@@ -63,3 +63,10 @@ def test():
     assert "import os" in out
     assert "!pip install pandas" not in out
     assert "!pip install dummy" in out
+
+
+def test_chat_agent_case():
+    a = """{"thoughts": "The user has chosen to use the plan with owl_v2 and specified a threshold of 0.4. I'll now generate the vision code based on this plan and the user's modification.", "response": "Certainly! I'll generate the code using owl_v2 with a threshold of 0.4 as you requested. Let me create that for you now.\n\n<execute_python>generate_vision_code(artifacts, 'count_workers_with_helmets.py', 'Can you write code to count the number of workers wearing helmets?', media=['/Users/dillonlaird/landing.ai/vision-agent/examples/chat/workspace/workers.png'], plan={'thoughts': 'Using owl_v2_image seems most appropriate as it can detect and count multiple objects given a text prompt. This tool is specifically designed for object detection tasks like counting workers wearing helmets.', 'instructions': ['Load the image using load_image(\'/Users/dillonlaird/landing.ai/vision-agent/examples/chat/workspace/workers.png\')', 'Use owl_v2_image with the prompt \'worker wearing helmet\' to detect and count workers with helmets', 'Count the number of detections returned by owl_v2_image to get the final count of workers wearing helmets']}, plan_thoughts='Use a threshold of 0.4 as specified by the user', plan_context_artifact='worker_helmet_plan.json')</execute_python>", "let_user_respond": false}"""
+    a_json = extract_json(a)
+    assert "thoughts" in a_json
+    assert "response" in a_json
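
The new test case feeds `extract_json` a payload whose string values contain apostrophes and single-quoted arguments, which a blanket `'` to `"` replacement would corrupt. A minimal stand-in for the reordered parsing, using only the standard library (the name `parse_lenient` is hypothetical, and the real `extract_json` has further fallbacks that are not reproduced here):

```python
import json
from typing import Any, Dict


def parse_lenient(json_str: str) -> Dict[str, Any]:
    """Try the single-quote rewrite first, then fall back to the raw text."""
    cleaned = json_str.replace("\n", " ").strip()
    try:
        # Handles payloads written entirely with single quotes.
        return json.loads(cleaned.replace("'", '"'))
    except json.JSONDecodeError:
        pass
    # Valid JSON with apostrophes inside its strings ends up here, because
    # the rewrite above turned those apostrophes into stray double quotes.
    return json.loads(cleaned)


single_quoted = "{'thoughts': 'ok', 'let_user_respond': false}"
with_apostrophe = '{"response": "I\'ll call f(\'a.py\') now"}'

print(parse_lenient(single_quoted))
print(parse_lenient(with_apostrophe))
```

With the unconditional replacement that this commit removes, the second input would fail to parse, since there would be no untouched copy of the string to fall back to.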

vision_agent/agent/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -7,3 +7,11 @@
     OpenAIVisionAgentCoder,
     VisionAgentCoder,
 )
+from .vision_agent_planner import (
+    AnthropicVisionAgentPlanner,
+    AzureVisionAgentPlanner,
+    OllamaVisionAgentPlanner,
+    OpenAIVisionAgentPlanner,
+    PlanContext,
+    VisionAgentPlanner,
+)

vision_agent/agent/agent_utils.py

Lines changed: 76 additions & 2 deletions
@@ -2,10 +2,17 @@
 import logging
 import re
 import sys
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
+
+from rich.console import Console
+from rich.style import Style
+from rich.syntax import Syntax
+
+import vision_agent.tools as T
 
 logging.basicConfig(stream=sys.stdout)
 _LOGGER = logging.getLogger(__name__)
+_CONSOLE = Console()
 
 
 def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
@@ -41,11 +48,16 @@ def _strip_markdown_code(inp_str: str) -> str:
 
 def extract_json(json_str: str) -> Dict[str, Any]:
     json_str_mod = json_str.replace("\n", " ").strip()
-    json_str_mod = json_str_mod.replace("'", '"')
     json_str_mod = json_str_mod.replace(": True", ": true").replace(
         ": False", ": false"
     )
 
+    # sometimes the json is in single quotes
+    try:
+        return json.loads(json_str_mod.replace("'", '"'))  # type: ignore
+    except json.JSONDecodeError:
+        pass
+
     try:
         return json.loads(json_str_mod)  # type: ignore
     except json.JSONDecodeError:
@@ -83,3 +95,65 @@ def remove_installs_from_code(code: str) -> str:
     pattern = r"\n!pip install.*?(\n|\Z)\n"
     code = re.sub(pattern, "", code, flags=re.DOTALL)
     return code
+
+
+def format_memory(memory: List[Dict[str, str]]) -> str:
+    output_str = ""
+    for i, m in enumerate(memory):
+        output_str += f"### Feedback {i}:\n"
+        output_str += f"Code {i}:\n```python\n{m['code']}```\n\n"
+        output_str += f"Feedback {i}: {m['feedback']}\n\n"
+        if "edits" in m:
+            output_str += f"Edits {i}:\n{m['edits']}\n"
+        output_str += "\n"
+
+    return output_str
+
+
+def format_plans(plans: Dict[str, Any]) -> str:
+    plan_str = ""
+    for k, v in plans.items():
+        plan_str += "\n" + f"{k}: {v['thoughts']}\n"
+        plan_str += " -" + "\n -".join([e for e in v["instructions"]])
+
+    return plan_str
+
+
+class DefaultImports:
+    """Container for default imports used in the code execution."""
+
+    common_imports = [
+        "import os",
+        "import numpy as np",
+        "from vision_agent.tools import *",
+        "from typing import *",
+        "from pillow_heif import register_heif_opener",
+        "register_heif_opener()",
+    ]
+
+    @staticmethod
+    def to_code_string() -> str:
+        return "\n".join(DefaultImports.common_imports + T.__new_tools__)
+
+    @staticmethod
+    def prepend_imports(code: str) -> str:
+        """Run this method to prepend the default imports to the code.
+        NOTE: be sure to run this method after the custom tools have been registered.
+        """
+        return DefaultImports.to_code_string() + "\n\n" + code
+
+
+def print_code(title: str, code: str, test: Optional[str] = None) -> None:
+    _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
+    _CONSOLE.print("=" * 30 + " Code " + "=" * 30)
+    _CONSOLE.print(
+        Syntax(
+            DefaultImports.prepend_imports(code),
+            "python",
+            theme="gruvbox-dark",
+            line_numbers=True,
+        )
+    )
+    if test:
+        _CONSOLE.print("=" * 30 + " Test " + "=" * 30)
+        _CONSOLE.print(Syntax(test, "python", theme="gruvbox-dark", line_numbers=True))
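
To show the input shape that `format_plans` expects, here is a small sketch that reproduces the function as it appears in this diff and runs it on made-up data. The plan names and contents are hypothetical, and the bullet prefix is copied as shown in the hunk, so its exact spacing is an assumption:

```python
from typing import Any, Dict


def format_plans(plans: Dict[str, Any]) -> str:
    # Each value is expected to carry "thoughts" (str) and "instructions" (list of str).
    plan_str = ""
    for k, v in plans.items():
        plan_str += "\n" + f"{k}: {v['thoughts']}\n"
        plan_str += " -" + "\n -".join([e for e in v["instructions"]])

    return plan_str


# Hypothetical plans, for illustration only.
plans = {
    "plan1": {
        "thoughts": "detect helmets with an object detector",
        "instructions": ["load the image", "count the detections"],
    },
    "plan2": {
        "thoughts": "ask a VQA model directly",
        "instructions": ["load the image", "ask how many helmets are visible"],
    },
}

print(format_plans(plans))
```

Each plan contributes a `name: thoughts` header followed by one dash-prefixed line per instruction, and consecutive plans are separated by the leading newline of the next header.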
