Skip to content

Commit

Permalink
Separate out planner (#256)
Browse files Browse the repository at this point in the history
* separated out planner, renamed chat methods

* fixed circular imports

* added type for plan context

* add planner as separate call to vision agent

* export plan context

* fixed circular imports

* fixed wrong key

* better json parsing

* more test cases for json parsing

* have planner visualize results

* add more guard rails to remove double chat

* revert changes with planning step for now

* revert to original prompts

* fix type issue

* fix format issue

* skip examples for flake8

* fix names and readme

* fixed type error

* fix countgd integ test

* synced code with new code interpreter arg
  • Loading branch information
dillonalaird authored Oct 11, 2024
1 parent 296af2d commit ca1dc57
Show file tree
Hide file tree
Showing 15 changed files with 1,174 additions and 740 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Linting
run: |
# stop the build if there are Python syntax errors or undefined names
poetry run flake8 . --exclude .venv --count --show-source --statistics
poetry run flake8 . --exclude .venv,examples --count --show-source --statistics
- name: Check Format
run: |
poetry run black --check --diff --color .
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ continuing, for example it may want to execute code and look at the output befor
letting the user respond.

### Chatting and Artifacts
If you run `chat_with_code` you will also notice an `Artifact` object. `Artifact`'s
If you run `chat_with_artifacts` you will also notice an `Artifact` object. `Artifact`s
are a way to sync files between local and remote environments. The agent will read and
write to the artifact object, which is just a pickle object, when it wants to save or
load files.
Expand All @@ -118,7 +118,7 @@ with open("image.png", "rb") as f:
artifacts["image.png"] = f.read()

agent = va.agent.VisionAgent()
response, artifacts = agent.chat_with_code(
response, artifacts = agent.chat_with_artifacts(
[
{
"role": "user",
Expand Down Expand Up @@ -298,11 +298,11 @@ mode by passing in the verbose argument:
```

### Detailed Usage
You can also have it return more information by calling `chat_with_workflow`. The format
You can also have it return more information by calling `generate_code`. The format
of the input is a list of dictionaries with the keys `role`, `content`, and `media`:

```python
>>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
>>> results = agent.generate_code([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
>>> print(results)
{
"code": "from vision_agent.tools import ..."
Expand Down Expand Up @@ -331,7 +331,7 @@ conv = [
"media": ["workers.png"],
}
]
result = agent.chat_with_workflow(conv)
result = agent.generate_code(conv)
code = result["code"]
conv.append({"role": "assistant", "content": code})
conv.append(
Expand All @@ -340,7 +340,7 @@ conv.append(
"content": "Can you also return the number of workers wearing safety gear?",
}
)
result = agent.chat_with_workflow(conv)
result = agent.generate_code(conv)
```


Expand Down
12 changes: 6 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ continuing, for example it may want to execute code and look at the output befor
letting the user respond.

### Chatting and Artifacts
If you run `chat_with_code` you will also notice an `Artifact` object. `Artifact`'s
If you run `chat_with_artifacts` you will also notice an `Artifact` object. `Artifact`s
are a way to sync files between local and remote environments. The agent will read and
write to the artifact object, which is just a pickle object, when it wants to save or
load files.
Expand All @@ -114,7 +114,7 @@ with open("image.png", "rb") as f:
artifacts["image.png"] = f.read()

agent = va.agent.VisionAgent()
response, artifacts = agent.chat_with_code(
response, artifacts = agent.chat_with_artifacts(
[
{
"role": "user",
Expand Down Expand Up @@ -294,11 +294,11 @@ mode by passing in the verbose argument:
```

### Detailed Usage
You can also have it return more information by calling `chat_with_workflow`. The format
You can also have it return more information by calling `generate_code`. The format
of the input is a list of dictionaries with the keys `role`, `content`, and `media`:

```python
>>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
>>> results = agent.generate_code([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?", "media": ["jar.jpg"]}])
>>> print(results)
{
"code": "from vision_agent.tools import ..."
Expand Down Expand Up @@ -327,7 +327,7 @@ conv = [
"media": ["workers.png"],
}
]
result = agent.chat_with_workflow(conv)
result = agent.generate_code(conv)
code = result["code"]
conv.append({"role": "assistant", "content": code})
conv.append(
Expand All @@ -336,7 +336,7 @@ conv.append(
"content": "Can you also return the number of workers wearing safety gear?",
}
)
result = agent.chat_with_workflow(conv)
result = agent.generate_code(conv)
```


Expand Down
32 changes: 21 additions & 11 deletions examples/chat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"style": {"bottom": "calc(50% - 4.25rem", "right": "0.4rem"},
}
# set artifacts remote_path to WORKSPACE
artifacts = va.tools.Artifacts(WORKSPACE / "artifacts.pkl")
artifacts = va.tools.meta_tools.Artifacts(WORKSPACE / "artifacts.pkl")
if Path("artifacts.pkl").exists():
artifacts.load("artifacts.pkl")
else:
Expand Down Expand Up @@ -109,16 +109,26 @@ def main():
len(st.session_state.messages) == 0
or prompt != st.session_state.messages[-1]["content"]
):
st.session_state.messages.append(
{"role": "user", "content": prompt}
)
messages.chat_message("user").write(prompt)
message_thread = threading.Thread(
target=update_messages,
args=(st.session_state.messages, message_lock),
)
message_thread.daemon = True
message_thread.start()
# occasionally resends the last user message twice
user_messages = [
msg
for msg in st.session_state.messages
if msg["role"] == "user"
]
last_user_message = None
if len(user_messages) > 0:
last_user_message = user_messages[-1]["content"]
if last_user_message is None or last_user_message != prompt:
st.session_state.messages.append(
{"role": "user", "content": prompt}
)
messages.chat_message("user").write(prompt)
message_thread = threading.Thread(
target=update_messages,
args=(st.session_state.messages, message_lock),
)
message_thread.daemon = True
message_thread.start()
st.session_state.input_text = ""

with tabs[1]:
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,4 +413,4 @@ def test_countgd_example_based_counting() -> None:
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert [res["label"] for res in result] == ["object"] * 24
7 changes: 7 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@ def test():
assert "import os" in out
assert "!pip install pandas" not in out
assert "!pip install dummy" in out


def test_chat_agent_case():
    """Check that extract_json parses a chat-agent reply that embeds code.

    The reply's "response" value contains an <execute_python> call full of
    single-quoted Python literals, inside an otherwise double-quoted JSON
    object.
    """
    raw_response = """{"thoughts": "The user has chosen to use the plan with owl_v2 and specified a threshold of 0.4. I'll now generate the vision code based on this plan and the user's modification.", "response": "Certainly! I'll generate the code using owl_v2 with a threshold of 0.4 as you requested. Let me create that for you now.\n\n<execute_python>generate_vision_code(artifacts, 'count_workers_with_helmets.py', 'Can you write code to count the number of workers wearing helmets?', media=['/Users/dillonlaird/landing.ai/vision-agent/examples/chat/workspace/workers.png'], plan={'thoughts': 'Using owl_v2_image seems most appropriate as it can detect and count multiple objects given a text prompt. This tool is specifically designed for object detection tasks like counting workers wearing helmets.', 'instructions': ['Load the image using load_image(\'/Users/dillonlaird/landing.ai/vision-agent/examples/chat/workspace/workers.png\')', 'Use owl_v2_image with the prompt \'worker wearing helmet\' to detect and count workers with helmets', 'Count the number of detections returned by owl_v2_image to get the final count of workers wearing helmets']}, plan_thoughts='Use a threshold of 0.4 as specified by the user', plan_context_artifact='worker_helmet_plan.json')</execute_python>", "let_user_respond": false}"""
    parsed = extract_json(raw_response)
    # Both top-level keys must survive parsing.
    for expected_key in ("thoughts", "response"):
        assert expected_key in parsed
8 changes: 8 additions & 0 deletions vision_agent/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,11 @@
OpenAIVisionAgentCoder,
VisionAgentCoder,
)
from .vision_agent_planner import (
AnthropicVisionAgentPlanner,
AzureVisionAgentPlanner,
OllamaVisionAgentPlanner,
OpenAIVisionAgentPlanner,
PlanContext,
VisionAgentPlanner,
)
78 changes: 76 additions & 2 deletions vision_agent/agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
import logging
import re
import sys
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from rich.console import Console
from rich.style import Style
from rich.syntax import Syntax

import vision_agent.tools as T

# Configure the root logger to write to stdout.
logging.basicConfig(stream=sys.stdout)
# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)
# Shared rich console; print_code writes its styled output through it.
_CONSOLE = Console()


def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -41,11 +48,16 @@ def _strip_markdown_code(inp_str: str) -> str:

def extract_json(json_str: str) -> Dict[str, Any]:
json_str_mod = json_str.replace("\n", " ").strip()
json_str_mod = json_str_mod.replace("'", '"')
json_str_mod = json_str_mod.replace(": True", ": true").replace(
": False", ": false"
)

# sometimes the json is in single quotes
try:
return json.loads(json_str_mod.replace("'", '"')) # type: ignore
except json.JSONDecodeError:
pass

try:
return json.loads(json_str_mod) # type: ignore
except json.JSONDecodeError:
Expand Down Expand Up @@ -83,3 +95,65 @@ def remove_installs_from_code(code: str) -> str:
pattern = r"\n!pip install.*?(\n|\Z)\n"
code = re.sub(pattern, "", code, flags=re.DOTALL)
return code


def format_memory(memory: List[Dict[str, str]]) -> str:
    """Render feedback entries as one text block.

    Each entry becomes a numbered section (numbering starts at 0) holding the
    code in a fenced ``python`` block, the feedback, and the edits when the
    entry has an ``"edits"`` key. Every section ends with a blank line.

    Args:
        memory: Entries with required keys ``"code"`` and ``"feedback"`` and
            an optional key ``"edits"``.

    Returns:
        The concatenated sections, or an empty string when ``memory`` is empty.

    Raises:
        KeyError: If an entry lacks ``"code"`` or ``"feedback"``.
    """
    sections = []
    for i, m in enumerate(memory):
        section = (
            f"### Feedback {i}:\n"
            f"Code {i}:\n```python\n{m['code']}```\n\n"
            f"Feedback {i}: {m['feedback']}\n\n"
        )
        if "edits" in m:
            section += f"Edits {i}:\n{m['edits']}\n"
        # Trailing newline separates this section from the next one.
        sections.append(section + "\n")

    # Join once instead of growing a string with repeated `+=`.
    return "".join(sections)


def format_plans(plans: Dict[str, Any]) -> str:
    """Render named plans as text: a header line, then one line per instruction.

    Each plan contributes ``"\\n<name>: <thoughts>\\n"`` followed by its
    instructions, each prefixed with `` -``. No newline is written after a
    plan's last instruction; the next plan's leading newline acts as the
    separator.

    Args:
        plans: Maps a plan name to a dict with the keys ``"thoughts"`` and
            ``"instructions"`` (an iterable of strings).

    Returns:
        The formatted plans, or an empty string when ``plans`` is empty.

    Raises:
        KeyError: If a plan lacks ``"thoughts"`` or ``"instructions"``.
    """
    parts = []
    for name, plan in plans.items():
        parts.append(f"\n{name}: {plan['thoughts']}\n")
        # str.join accepts the iterable directly; no intermediate copy needed.
        parts.append(" -" + "\n -".join(plan["instructions"]))

    return "".join(parts)


class DefaultImports:
    """Holds the import statements placed in front of code before it is run."""

    # Statements that are always included, in this order.
    common_imports = [
        "import os",
        "import numpy as np",
        "from vision_agent.tools import *",
        "from typing import *",
        "from pillow_heif import register_heif_opener",
        "register_heif_opener()",
    ]

    @staticmethod
    def to_code_string() -> str:
        """Return the common imports followed by ``T.__new_tools__``, one per line."""
        import_lines = DefaultImports.common_imports + T.__new_tools__
        return "\n".join(import_lines)

    @staticmethod
    def prepend_imports(code: str) -> str:
        """Return ``code`` with the default import block and a blank line before it.

        NOTE: be sure to run this method after the custom tools have been registered.
        """
        header = DefaultImports.to_code_string()
        return f"{header}\n\n{code}"


def print_code(title: str, code: str, test: Optional[str] = None) -> None:
    """Print a titled, syntax-highlighted listing of ``code`` and, optionally, ``test``.

    The code is shown with the default imports prepended (via
    ``DefaultImports.prepend_imports``); the test, when given and non-empty,
    is shown as-is under its own banner.

    Args:
        title: Heading printed above the listing.
        code: Python source to display.
        test: Optional Python test source to display after the code.
    """
    rule = "=" * 30

    _CONSOLE.print(title, style=Style(bgcolor="dark_orange3", bold=True))
    _CONSOLE.print(f"{rule} Code {rule}")
    highlighted_code = Syntax(
        DefaultImports.prepend_imports(code),
        "python",
        theme="gruvbox-dark",
        line_numbers=True,
    )
    _CONSOLE.print(highlighted_code)

    # Nothing more to show when no test source was supplied.
    if not test:
        return

    _CONSOLE.print(f"{rule} Test {rule}")
    highlighted_test = Syntax(test, "python", theme="gruvbox-dark", line_numbers=True)
    _CONSOLE.print(highlighted_test)
Loading

0 comments on commit ca1dc57

Please sign in to comment.