
Commit

automatically save files to artifacts
dillonalaird committed Oct 12, 2024
1 parent 7e69a79 commit fe26385
Showing 3 changed files with 88 additions and 71 deletions.
24 changes: 18 additions & 6 deletions vision_agent/agent/vision_agent.py
@@ -19,6 +19,7 @@
META_TOOL_DOCSTRING,
Artifacts,
check_and_load_image,
extract_and_save_files_to_artifacts,
use_extra_vision_agent_args,
)
from vision_agent.utils import CodeInterpreterFactory
@@ -36,7 +37,7 @@ class BoilerplateCode:
pre_code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning",
"artifacts = Artifacts('{remote_path}')",
"artifacts.load('{remote_path}')",
]
@@ -94,7 +95,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
elif chat_i["role"] == "observation":
conversation += f"OBSERVATION:\n{chat_i['content']}\n\n"
elif chat_i["role"] == "assistant":
conversation += f"AGENT: {format_agent_message(chat_i['content'])}\n\n"
conversation += f"AGENT: {format_agent_message(chat_i['content'])}\n\n" # type: ignore
else:
raise ValueError(f"role {chat_i['role']} is not supported")

@@ -127,11 +128,15 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:


def execute_code_action(
code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
artifacts: Artifacts,
code: str,
code_interpreter: CodeInterpreter,
artifact_remote_path: str,
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
BoilerplateCode.add_boilerplate(code, remote_path=artifact_remote_path)
)
extract_and_save_files_to_artifacts(artifacts, code)

obs = str(result.logs)
if result.error:
@@ -140,6 +145,7 @@ def execute_code_action(


def execute_user_code_action(
artifacts: Artifacts,
last_user_message: Message,
code_interpreter: CodeInterpreter,
artifact_remote_path: str,
@@ -159,7 +165,7 @@
if user_code_action is not None:
user_code_action = use_extra_vision_agent_args(user_code_action, False)
user_result, user_obs = execute_code_action(
user_code_action, code_interpreter, artifact_remote_path
artifacts, user_code_action, code_interpreter, artifact_remote_path
)
if user_result.error:
user_obs += f"\n{user_result.error}"
@@ -385,7 +391,10 @@ def chat_with_artifacts(
self.streaming_message({"role": "observation", "content": artifacts_loaded})

user_result, user_obs = execute_user_code_action(
last_user_message, code_interpreter, str(remote_artifacts_path)
artifacts,
last_user_message,
code_interpreter,
str(remote_artifacts_path),
)
finished = user_result is not None and user_obs is not None
if user_result is not None and user_obs is not None:
@@ -456,7 +465,10 @@ def chat_with_artifacts(

if code_action is not None:
result, obs = execute_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
artifacts,
code_action,
code_interpreter,
str(remote_artifacts_path),
)

media_obs = check_and_load_image(code_action)
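For orientation, here is a minimal sketch of how the updated execute_code_action is expected to be invoked after this change. The artifact path, the code string, and the use of CodeInterpreterFactory.new_instance() with default arguments are illustrative assumptions, not code taken from this commit.

from vision_agent.agent.vision_agent import execute_code_action
from vision_agent.tools.meta_tools import Artifacts
from vision_agent.utils import CodeInterpreterFactory

# Hypothetical paths and code string, for illustration only.
artifacts = Artifacts("/home/user/artifacts.pkl")
code = (
    "from code import count_workers_with_helmets\n"
    "print(count_workers_with_helmets('workers.png', 'workers_viz.png'))"
)

with CodeInterpreterFactory.new_instance() as code_interpreter:
    # The artifacts object is now passed in so any files the executed code
    # writes (e.g. workers_viz.png) are picked up by
    # extract_and_save_files_to_artifacts after execution.
    result, obs = execute_code_action(
        artifacts, code, code_interpreter, "/home/user/artifacts.pkl"
    )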
7 changes: 3 additions & 4 deletions vision_agent/agent/vision_agent_prompts.py
@@ -187,10 +187,10 @@
16| return count
[End of artifact]
AGENT: <thinking>I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.</thinking>
AGENT: <thinking>I have generated the code to count the workers with helmets in the image, I must now run the code and print the output.</thinking>
<execute_python>
from code import count_workers_with_helmets
print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png', 'workers_viz.png')
print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))
</execute_python>
<let_user_respond>false</let_user_respond>
@@ -234,11 +234,10 @@ def count_workers_with_helmets(image_path: str, output_path: str):
count = 0
[End of edits]
AGENT: <thinking>I have updated the code to use the fine tuned model, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.</thinking>
AGENT: <thinking>I have updated the code to use the fine tuned model, I must now run the code and print the output.</thinking>
<execute_python>
from code import count_workers_with_helmets
print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))
write_media_artifact(artifacts, 'workers_viz.png', 'workers_viz.png')
</execute_python>
<let_user_respond>false</let_user_respond>
128 changes: 67 additions & 61 deletions vision_agent/tools/meta_tools.py
@@ -1,4 +1,3 @@
import base64
import difflib
import json
import os
@@ -9,7 +8,6 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
from IPython.display import display
from redbaron import RedBaron # type: ignore

@@ -22,8 +20,7 @@
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
from vision_agent.utils.execute import Execution, MimeType
from vision_agent.utils.image_utils import convert_to_b64, numpy_to_bytes
from vision_agent.utils.video import frames_to_bytes
from vision_agent.utils.image_utils import convert_to_b64

CURRENT_FILE = None
CURRENT_LINE = 0
@@ -393,19 +390,6 @@ def generate_vision_plan(
redisplay_results(response.test_results)
response.test_results = None
artifacts[name] = response.model_dump_json()
media_names = extract_json(
AnthropicLMM()( # type: ignore
f"""Extract any media file names from this output in the following JSON format:
{{"media": ["image1.jpg", "image2.jpg"]}}
{artifacts[name]}"""
)
)
if "media" in media_names and isinstance(media_names, dict):
for media in media_names["media"]:
if isinstance(media, str):
with open(media, "rb") as f:
artifacts[media] = f.read()

output_str = f"[Start Plan Context, saved at {name}]"
for plan in response.plans.keys():
Expand Down Expand Up @@ -466,6 +450,12 @@ def detect_dogs(image_path: str):
test_multi_plan=test_multi_plan,
custom_tool_names=custom_tool_names,
)

# capture and save any files that were saved in the code to the artifacts
extract_and_save_files_to_artifacts(
artifacts, response["code"] + "\n" + response["test"]
)

redisplay_results(response["test_result"])
code = response["code"]
artifacts[name] = code
@@ -546,6 +536,11 @@ def detect_dogs(image_path: str):
test_multi_plan=False,
custom_tool_names=custom_tool_names,
)
# capture and save any files that were saved in the code to the artifacts
extract_and_save_files_to_artifacts(
artifacts, response["code"] + "\n" + response["test"]
)

redisplay_results(response["test_result"])
code = response["code"]
artifacts[name] = code
@@ -567,49 +562,6 @@ def detect_dogs(image_path: str):
return view_lines(code_lines, 0, total_lines, name, total_lines)


def write_media_artifact(
artifacts: Artifacts,
name: str,
media: Union[str, np.ndarray, List[np.ndarray]],
fps: Optional[float] = None,
) -> str:
"""Writes a media file to the artifacts object.
Parameters:
artifacts (Artifacts): The artifacts object to save the media to.
name (str): The name of the media artifact to save.
media (Union[str, np.ndarray, List[np.ndarray]]): The media to save, can either
be a file path, single image or list of frames for a video.
fps (Optional[float]): The frames per second if you are writing a video.
"""
if isinstance(media, str):
with open(media, "rb") as f:
media_bytes = f.read()
elif isinstance(media, list):
media_bytes = frames_to_bytes(media, fps=fps if fps is not None else 1.0)
elif isinstance(media, np.ndarray):
media_bytes = numpy_to_bytes(media)
else:
print(f"[Invalid media type {type(media)}]")
return f"[Invalid media type {type(media)}]"
artifacts[name] = media_bytes
print(f"[Media {name} saved]")
display(
{
MimeType.APPLICATION_ARTIFACT: json.dumps(
{
"name": name,
"action": "create",
"content": base64.b64encode(media_bytes).decode("utf-8"),
"contentType": "media_output",
}
)
},
raw=True,
)
return f"[Media {name} saved]"


def list_artifacts(artifacts: Artifacts) -> str:
"""Lists all the artifacts that have been loaded into the artifacts object."""
output_str = artifacts.show()
@@ -813,6 +765,61 @@ def use_object_detection_fine_tuning(
return diff


def extract_and_save_files_to_artifacts(artifacts: Artifacts, code: str) -> None:
"""Extracts and saves files used in the code to the artifacts object.
Parameters:
artifacts (Artifacts): The artifacts object to save the files to.
code (str): The code to extract the files from.
"""
try:
response = extract_json(
AnthropicLMM()( # type: ignore
f"""You are a helpful AI assistant. Your job is to look at a snippet of code and return the file paths that are being saved in the file. Below is the code snippet:
```python
{code}
```
Return the file paths in the following JSON format:
{{"file_paths": ["/path/to/image1.jpg", "/other/path/to/data.json"]}}"""
)
)
except json.JSONDecodeError:
return

text_file_ext = [
".txt",
".md",
"rtf",
".html",
".htm",
"xml",
".json",
".csv",
".tsv",
".yaml",
".yml",
".toml",
".conf",
".env" ".ini",
".log",
".py",
".java",
".js",
".cpp",
".c" ".sql",
".sh",
]

if "file_paths" in response and isinstance(response["file_paths"], list):
for file_path in response["file_paths"]:
read_mode = "r" if Path(file_path).suffix in text_file_ext else "rb"
if Path(file_path).is_file():
with open(file_path, read_mode) as f:
artifacts[Path(file_path).name] = f.read()


META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
@@ -822,7 +829,6 @@ def use_object_detection_fine_tuning(
generate_vision_plan,
generate_vision_code,
edit_vision_code,
write_media_artifact,
view_media_artifact,
object_detection_fine_tuning,
use_object_detection_fine_tuning,
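As a standalone illustration, below is a hedged sketch of how the new helper behaves on its own. The generated snippet, file names, and the local artifacts path are invented for the example, and the call requires Anthropic credentials since the helper routes through AnthropicLMM.

from vision_agent.tools.meta_tools import (
    Artifacts,
    extract_and_save_files_to_artifacts,
)

artifacts = Artifacts("artifacts.pkl")  # hypothetical artifacts store
generated_code = """
import cv2
img = cv2.imread('workers.png')
cv2.imwrite('workers_viz.png', img)
"""

# The helper asks the LMM which paths the snippet saves, then stores each
# path that exists on disk under its basename, reading known text extensions
# in "r" mode and everything else in "rb" mode.
extract_and_save_files_to_artifacts(artifacts, generated_code)
# If workers_viz.png exists locally, artifacts['workers_viz.png'] now holds it.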
