Skip to content

Commit

Permalink
added ability to write media files to artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 29, 2024
1 parent d83857e commit 51503b9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 27 deletions.
17 changes: 9 additions & 8 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BoilerplateCode:
pre_code = [
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import Artifacts, open_artifact, create_artifact, edit_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, write_media_artifact",
"artifacts = Artifacts('{remote_path}')",
"artifacts.load('{remote_path}')",
]
Expand Down Expand Up @@ -198,13 +198,14 @@ def chat_with_code(
for chat_i in int_chat:
if "media" in chat_i:
for media in chat_i["media"]:
media = code_interpreter.upload_file(cast(str, media))
chat_i["content"] += f" Media name {media}" # type: ignore
# Save dummy value for now since we just need to know the path
# name in the key 'media'. Later on we can add artifact support
# for byte data.
artifacts.artifacts[Path(media).name] = None
media_list.append(media)
media = cast(str, media)
artifacts.artifacts[Path(media).name] = open(media, "rb").read()

media_remote_path = (
Path(code_interpreter.remote_path) / Path(media).name
)
chat_i["content"] += f" Media name {media_remote_path}" # type: ignore
media_list.append(media_remote_path)

int_chat = cast(
List[Message],
Expand Down
53 changes: 34 additions & 19 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def load(self, file_path: Union[str, Path]) -> None:
self.artifacts = pkl.load(f)
for k, v in self.artifacts.items():
if v is not None:
with open(self.remote_save_path.parent / k, "w") as f:
mode = "w" if isinstance(v, str) else "wb"
with open(self.remote_save_path.parent / k, mode) as f:
f.write(v)

def show(self) -> str:
Expand All @@ -87,7 +88,7 @@ def __iter__(self) -> Any:
def __getitem__(self, name: str) -> Any:
return self.artifacts[name]

def __setitem__(self, name: str, value: str) -> None:
def __setitem__(self, name: str, value: Any) -> None:
self.artifacts[name] = value

def __contains__(self, name: str) -> bool:
Expand Down Expand Up @@ -119,11 +120,11 @@ def view_lines(
return return_str


def open_artifact(
def open_code_artifact(
artifacts: Artifacts, name: str, line_num: int = 0, window_size: int = 100
) -> str:
"""Opens the provided artifact. If `line_num` is provided, the window will be moved
to include that line. It only shows the first 100 lines by default! Max
"""Opens the provided code artifact. If `line_num` is provided, the window will be
moved to include that line. It only shows the first 100 lines by default! Max
`window_size` supported is 2000.
Parameters:
Expand All @@ -148,8 +149,8 @@ def open_artifact(
return view_lines(lines, line_num, window_size, name, total_lines)


def create_artifact(artifacts: Artifacts, name: str) -> str:
"""Creates a new artifiact with the given name.
def create_code_artifact(artifacts: Artifacts, name: str) -> str:
"""Creates a new code artifiact with the given name.
Parameters:
artifacts (Artifacts): The artifacts object to add the new artifact to.
Expand All @@ -164,15 +165,15 @@ def create_artifact(artifacts: Artifacts, name: str) -> str:
return return_str


def edit_artifact(
def edit_code_artifact(
artifacts: Artifacts, name: str, start: int, end: int, content: str
) -> str:
"""Edits the given artifact with the provided content. The content will be inserted
between the `start` and `end` line numbers. If the `start` and `end` are the same,
the content will be inserted at the `start` line number. If the `end` is greater
than the total number of lines in the file, the content will be inserted at the end
of the file. If the `start` or `end` are negative, the function will return an
error message.
"""Edits the given code artifact with the provided content. The content will be
inserted between the `start` and `end` line numbers. If the `start` and `end` are
the same, the content will be inserted at the `start` line number. If the `end` is
greater than the total number of lines in the file, the content will be inserted at
the end of the file. If the `start` or `end` are negative, the function will return
an error message.
Parameters:
artifacts (Artifacts): The artifacts object to edit the artifact from.
Expand Down Expand Up @@ -237,7 +238,7 @@ def edit_artifact(

artifacts[name] = "".join(edited_lines)

return open_artifact(artifacts, name, cur_line)
return open_code_artifact(artifacts, name, cur_line)


def generate_vision_code(
Expand Down Expand Up @@ -274,7 +275,7 @@ def detect_dogs(image_path: str):
agent = va.agent.VisionAgentCoder()

fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
response = agent.chat_with_workflow(fixed_chat, test_multi_plan=False)
response = agent.chat_with_workflow(fixed_chat, test_multi_plan=True)
code = response["code"]
artifacts[name] = code
code_lines = code.splitlines(keepends=True)
Expand Down Expand Up @@ -335,6 +336,19 @@ def detect_dogs(image_path: str):
return view_lines(code_lines, 0, total_lines, name, total_lines)


def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
"""Writes a media file to the artifacts object.
Parameters:
artifacts (Artifacts): The artifacts object to save the media to.
local_path (str): The local path to the media file.
"""
with open(local_path, "rb") as f:
media = f.read()
artifacts[Path(local_path).name] = media
return f"[Media {Path(local_path).name} saved]"


def get_tool_descriptions() -> str:
"""Returns a description of all the tools that `generate_vision_code` has access to.
Helpful for answering questions about what types of vision tasks you can do with
Expand All @@ -345,10 +359,11 @@ def get_tool_descriptions() -> str:
META_TOOL_DOCSTRING = get_tool_documentation(
[
get_tool_descriptions,
open_artifact,
create_artifact,
edit_artifact,
open_code_artifact,
create_code_artifact,
edit_code_artifact,
generate_vision_code,
edit_vision_code,
write_media_artifact,
]
)

0 comments on commit 51503b9

Please sign in to comment.