Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image support for VA #162

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 23 additions & 58 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from vision_agent.agent.vision_agent_prompts import (
CODE,
FIX_BUG,
FULL_TASK,
PLAN,
REFLECT,
SIMPLE_TEST,
USER_REQ,
)
Expand Down Expand Up @@ -145,7 +143,7 @@ def write_plan(
tool_desc: str,
working_memory: str,
model: LMM,
) -> List[Dict[str, str]]:
) -> Dict[str, Any]:
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")
Expand All @@ -154,13 +152,14 @@ def write_plan(
context = USER_REQ.format(user_request=user_request)
prompt = PLAN.format(context=context, tool_desc=tool_desc, feedback=working_memory)
chat[-1]["content"] = prompt
return extract_json(model.chat(chat))["plan"] # type: ignore
return extract_json(model.chat(chat))


@traceable
def write_code(
coder: LMM,
chat: List[Message],
image_desc: str,
tool_info: str,
feedback: str,
) -> str:
Expand All @@ -173,6 +172,7 @@ def write_code(
docstring=tool_info,
question=user_request,
feedback=feedback,
image_desc=image_desc,
)
chat[-1]["content"] = prompt
return extract_code(coder(chat))
Expand Down Expand Up @@ -203,26 +203,9 @@ def write_test(
return extract_code(tester(chat))


@traceable
def reflect(
chat: List[Message],
plan: str,
code: str,
model: LMM,
) -> Dict[str, Union[str, bool]]:
chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")

user_request = chat[-1]["content"]
context = USER_REQ.format(user_request=user_request)
prompt = REFLECT.format(context=context, plan=plan, code=code)
chat[-1]["content"] = prompt
return extract_json(model(chat))


def write_and_test_code(
chat: List[Message],
image_desc: str,
tool_info: str,
tool_utils: str,
working_memory: List[Dict[str, str]],
Expand All @@ -241,7 +224,13 @@ def write_and_test_code(
"status": "started",
}
)
code = write_code(coder, chat, tool_info, format_memory(working_memory))
code = write_code(
coder,
chat,
image_desc,
f"{tool_info}\n{tool_utils}",
format_memory(working_memory),
)
test = write_test(
tester, chat, tool_utils, code, format_memory(working_memory), media
)
Expand Down Expand Up @@ -543,7 +532,6 @@ def __call__(
def chat_with_workflow(
self,
chat: List[Message],
self_reflection: bool = False,
display_visualization: bool = False,
) -> Dict[str, Any]:
"""Chat with Vision Agent and return intermediate information regarding the task.
Expand All @@ -554,7 +542,6 @@ def chat_with_workflow(
[{"role": "user", "content": "describe your task here..."}]
or if it contains media files, it should be in the format of:
[{"role": "user", "content": "describe your task here...", "media": ["image1.jpg", "image2.jpg"]}]
self_reflection (bool): Whether to reflect on the task and debug the code.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).

Expand All @@ -581,7 +568,10 @@ def chat_with_workflow(

int_chat = cast(
List[Message],
[{"role": c["role"], "content": c["content"]} for c in chat],
[
{"role": c["role"], "content": c["content"], "media": c["media"]}
for c in chat
],
)

code = ""
Expand All @@ -599,13 +589,14 @@ def chat_with_workflow(
"status": "started",
}
)
plan_i = write_plan(
planning = write_plan(
int_chat,
T.TOOL_DESCRIPTIONS,
format_memory(working_memory),
self.planner,
)
plan_i_str = "\n-".join([e["instructions"] for e in plan_i])
plan_i = planning["plan"]
image_desc = planning["image_desc"]

self.log_progress(
{
Expand All @@ -626,7 +617,10 @@ def chat_with_workflow(
self.verbosity,
)
results = write_and_test_code(
chat=int_chat,
chat=[
{"role": c["role"], "content": c["content"]} for c in int_chat
],
image_desc=image_desc,
tool_info=tool_info,
tool_utils=T.UTILITIES_DOCSTRING,
working_memory=working_memory,
Expand All @@ -644,35 +638,6 @@ def chat_with_workflow(
working_memory.extend(results["working_memory"]) # type: ignore
plan.append({"code": code, "test": test, "plan": plan_i})

if not self_reflection:
break

self.log_progress(
{
"type": "self_reflection",
"status": "started",
}
)
reflection = reflect(
int_chat,
FULL_TASK.format(
user_request=chat[0]["content"], subtasks=plan_i_str
),
code,
self.planner,
)
if self.verbosity > 0:
_LOGGER.info(f"Reflection: {reflection}")
feedback = cast(str, reflection["feedback"])
success = cast(bool, reflection["success"])
self.log_progress(
{
"type": "self_reflection",
"status": "completed" if success else "failed",
"payload": reflection,
}
)
working_memory.append({"code": f"{code}\n{test}", "feedback": feedback})
retries += 1

execution_result = cast(Execution, results["test_result"])
Expand Down
14 changes: 9 additions & 5 deletions vision_agent/agent/vision_agent_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@
{feedback}

**Instructions**:
1. Based on the context and tools you have available, write a plan of subtasks to achieve the user request.
2. Go over the user's request step by step and ensure each step is represented as a clear subtask in your plan.
1. Based on the context and tools you have available, create a plan of subtasks to achieve the user request.
2. Provide a detailed description of the image, be sure to include any text you see in the image and whether or not the predominant object count is many, over a dozen, or just a few.
3. Go over the user's request step by step and ensure each step is represented as a clear subtask in your plan.

Output a JSON object in the following format

```json
{{
"image_desc": str # description of the image you are working with,
"thoughts": str # any thoughts you have about how to formulate the plan based on the image information,
"plan":
[
{{
Expand Down Expand Up @@ -67,12 +70,14 @@
**Previous Feedback**:
{feedback}

**Image Description**:
{image_desc}

**Instructions**:
1. **Understand and Clarify**: Make sure you understand the task.
2. **Algorithm/Method Selection**: Decide on the most efficient way.
2. **Algorithm/Method Selection**: Decide on the most efficient implementation utilizing the image description and tools available.
3. **Pseudocode Creation**: Write down the steps you will follow in pseudocode.
4. **Code Generation**: Translate your pseudocode into executable Python code. Ensure you use correct arguments, remember coordinates are always returned normalized from `vision_agent.tools`. All images from `vision_agent.tools` are in RGB format, red is (255, 0, 0) and blue is (0, 0, 255).
5. **Logging**: Log the output of the custom functions that were provided to you from `from vision_agent.tools import *`. Use a debug flag in the function parameters to toggle logging on and off.
"""

TEST = """
Expand Down Expand Up @@ -147,7 +152,6 @@ def find_text(image_path: str, text: str) -> str:
```
"""


SIMPLE_TEST = """
**Role**: As a tester, your task is to create a simple test case for the provided code. This test case should verify the fundamental functionality under normal conditions.

Expand Down
63 changes: 43 additions & 20 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import io
import json
import logging
import os
Expand All @@ -8,19 +9,48 @@

import requests
from openai import AzureOpenAI, OpenAI
from PIL import Image

import vision_agent.tools as T
from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT

_LOGGER = logging.getLogger(__name__)


def encode_image(image: Union[str, Path]) -> str:
with open(image, "rb") as f:
encoded_image = base64.b64encode(f.read()).decode("utf-8")
def encode_image_bytes(image: bytes) -> str:
image = Image.open(io.BytesIO(image)).convert("RGB") # type: ignore
buffer = io.BytesIO()
image.save(buffer, format="PNG") # type: ignore
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
return encoded_image


def encode_media(media: Union[str, Path]) -> str:
extension = "png"
extension = Path(media).suffix
if extension.lower() not in {
".jpg",
".jpeg",
".png",
".webp",
".bmp",
".mp4",
".mov",
}:
raise ValueError(f"Unsupported image extension: {extension}")

image_bytes = b""
if extension.lower() in {".mp4", ".mov"}:
frames = T.extract_frames(media)
image = frames[len(frames) // 2]
buffer = io.BytesIO()
Image.fromarray(image[0]).convert("RGB").save(buffer, format="PNG")
image_bytes = buffer.getvalue()
else:
image_bytes = open(media, "rb").read()
return encode_image_bytes(image_bytes)


TextOrImage = Union[str, List[Union[str, Path]]]
Message = Dict[str, TextOrImage]

Expand Down Expand Up @@ -54,7 +84,7 @@ def __init__(
self,
model_name: str = "gpt-4o",
api_key: Optional[str] = None,
max_tokens: int = 1024,
max_tokens: int = 4096,
json_mode: bool = False,
**kwargs: Any,
):
Expand Down Expand Up @@ -97,20 +127,14 @@ def chat(
fixed_c = {"role": c["role"]}
fixed_c["content"] = [{"type": "text", "text": c["content"]}] # type: ignore
if "media" in c:
for image in c["media"]:
extension = Path(image).suffix
if extension.lower() == ".jpeg" or extension.lower() == ".jpg":
extension = "jpg"
elif extension.lower() == ".png":
extension = "png"
else:
raise ValueError(f"Unsupported image extension: {extension}")
encoded_image = encode_image(image)
for media in c["media"]:
encoded_media = encode_media(media)

fixed_c["content"].append( # type: ignore
{
"type": "image_url",
"image_url": {
"url": f"data:image/{extension};base64,{encoded_image}", # type: ignore
"url": f"data:image/png;base64,{encoded_media}", # type: ignore
"detail": "low",
},
},
Expand Down Expand Up @@ -138,13 +162,12 @@ def generate(
]
if media and len(media) > 0:
for m in media:
extension = Path(m).suffix
encoded_image = encode_image(m)
encoded_media = encode_media(m)
message[0]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/{extension};base64,{encoded_image}",
"url": f"data:image/png;base64,{encoded_media}",
"detail": "low",
},
},
Expand Down Expand Up @@ -241,7 +264,7 @@ def __init__(
api_key: Optional[str] = None,
api_version: str = "2024-02-01",
azure_endpoint: Optional[str] = None,
max_tokens: int = 1024,
max_tokens: int = 4096,
json_mode: bool = False,
**kwargs: Any,
):
Expand Down Expand Up @@ -312,7 +335,7 @@ def chat(
fixed_chat = []
for message in chat:
if "media" in message:
message["images"] = [encode_image(m) for m in message["media"]]
message["images"] = [encode_media(m) for m in message["media"]]
del message["media"]
fixed_chat.append(message)
url = f"{self.url}/chat"
Expand Down Expand Up @@ -343,7 +366,7 @@ def generate(
json_data = json.dumps(data)
if media and len(media) > 0:
for m in media:
data["images"].append(encode_image(m)) # type: ignore
data["images"].append(encode_media(m)) # type: ignore

response = requests.post(url, data=json_data)

Expand Down
13 changes: 9 additions & 4 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,10 @@ def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution":
return Execution(
error=Error(
name=exec.__class__.__name__,
value=str(exec),
traceback_raw=traceback_raw,
value=_remove_escape_and_color_codes(str(exec)),
traceback_raw=[
_remove_escape_and_color_codes(line) for line in traceback_raw
],
)
)

Expand All @@ -373,8 +375,11 @@ def from_e2b_execution(exec: E2BExecution) -> "Execution": # type: ignore
error=(
Error(
name=exec.error.name,
value=exec.error.value,
traceback_raw=exec.error.traceback_raw,
value=_remove_escape_and_color_codes(exec.error.value),
traceback_raw=[
_remove_escape_and_color_codes(line)
for line in exec.error.traceback_raw
],
)
if exec.error
else None
Expand Down
Loading