Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: support artifact name display #236

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vision_agent.lmm import LMM, Message, OpenAILMM
from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
from vision_agent.tools.tools import extract_frames, save_video
from vision_agent.utils import CodeInterpreterFactory
from vision_agent.utils.execute import CodeInterpreter, Execution

Expand Down Expand Up @@ -224,9 +225,20 @@ def chat_with_code(
for media in chat_i["media"]:
if type(media) is str and media.startswith(("http", "https")):
# TODO: Ideally we should not call VA.tools here; revisit how to better support remote images later.
file_path = Path(media).name
ndarray = load_image(media)
save_image(ndarray, file_path)
file_path = str(
Path(self.local_artifacts_path).parent
/ Path(media).name
)
if file_path.lower().endswith(
".mp4"
) or file_path.lower().endswith(".mov"):
video_frames = extract_frames(media)
save_video(
[frame for frame, _ in video_frames], file_path
)
else:
ndarray = load_image(media)
save_image(ndarray, file_path)
media = file_path
else:
media = cast(str, media)
Expand Down
33 changes: 19 additions & 14 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,27 @@ def redisplay_results(execution: Execution) -> None:
"""
for result in execution.results:
if result.text is not None:
display({MimeType.TEXT_PLAIN: result.text})
display({MimeType.TEXT_PLAIN: result.text}, raw=True)
if result.html is not None:
display({MimeType.TEXT_HTML: result.html})
display({MimeType.TEXT_HTML: result.html}, raw=True)
if result.markdown is not None:
display({MimeType.TEXT_MARKDOWN: result.markdown})
display({MimeType.TEXT_MARKDOWN: result.markdown}, raw=True)
if result.svg is not None:
display({MimeType.IMAGE_SVG: result.svg})
display({MimeType.IMAGE_SVG: result.svg}, raw=True)
if result.png is not None:
display({MimeType.IMAGE_PNG: result.png})
display({MimeType.IMAGE_PNG: result.png}, raw=True)
if result.jpeg is not None:
display({MimeType.IMAGE_JPEG: result.jpeg})
display({MimeType.IMAGE_JPEG: result.jpeg}, raw=True)
if result.mp4 is not None:
display({MimeType.VIDEO_MP4_B64: result.mp4})
display({MimeType.VIDEO_MP4_B64: result.mp4}, raw=True)
if result.latex is not None:
display({MimeType.TEXT_LATEX: result.latex})
display({MimeType.TEXT_LATEX: result.latex}, raw=True)
if result.json is not None:
display({MimeType.APPLICATION_JSON: result.json})
display({MimeType.APPLICATION_JSON: result.json}, raw=True)
if result.artifact_name is not None:
display({MimeType.TEXT_ARTIFACT_NAME: result.artifact_name}, raw=True)
if result.extra is not None:
display(result.extra)
display(result.extra, raw=True)


class Artifacts:
Expand Down Expand Up @@ -208,7 +210,7 @@ def create_code_artifact(artifacts: Artifacts, name: str) -> str:
return_str = f"[Artifact {name} created]"
print(return_str)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return return_str


Expand Down Expand Up @@ -292,7 +294,7 @@ def edit_code_artifact(

artifacts[name] = "".join(edited_lines)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return open_code_artifact(artifacts, name, cur_line)


Expand Down Expand Up @@ -348,7 +350,7 @@ def detect_dogs(image_path: str):
code_lines = code.splitlines(keepends=True)
total_lines = len(code_lines)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return view_lines(code_lines, 0, total_lines, name, total_lines)


Expand Down Expand Up @@ -413,7 +415,7 @@ def detect_dogs(image_path: str):
code_lines = code.splitlines(keepends=True)
total_lines = len(code_lines)

display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return view_lines(code_lines, 0, total_lines, name, total_lines)


Expand All @@ -427,6 +429,7 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
with open(local_path, "rb") as f:
media = f.read()
artifacts[Path(local_path).name] = media
display({MimeType.TEXT_ARTIFACT_NAME: Path(local_path).name}, raw=True)
return f"[Media {Path(local_path).name} saved]"


Expand Down Expand Up @@ -592,6 +595,8 @@ def replacer(match: re.Match) -> str:

diff = get_diff_with_prompts(name, code, new_code)
print(diff)

display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
return diff


Expand Down
5 changes: 5 additions & 0 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class MimeType(str, Enum):
TEXT_LATEX = "text/latex"
APPLICATION_JSON = "application/json"
APPLICATION_JAVASCRIPT = "application/javascript"
TEXT_ARTIFACT_NAME = "text/artifact/name"


class FileSerializer:
Expand Down Expand Up @@ -103,6 +104,7 @@ class Result:
latex: Optional[str] = None
json: Optional[Dict[str, Any]] = None
javascript: Optional[str] = None
artifact_name: Optional[str] = None
extra: Optional[Dict[str, Any]] = None
"Extra data that can be included. Not part of the standard types."

Expand All @@ -127,6 +129,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]):
self.latex = data.pop(MimeType.TEXT_LATEX, None)
self.json = data.pop(MimeType.APPLICATION_JSON, None)
self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None)
self.artifact_name = data.pop(MimeType.TEXT_ARTIFACT_NAME, None)
self.extra = data
# Only keeping the PNG representation if both PNG and JPEG are present
if self.png and self.jpeg:
Expand Down Expand Up @@ -204,6 +207,8 @@ def formats(self) -> Iterable[str]:
formats.append("javascript")
if self.mp4:
formats.append("mp4")
if self.artifact_name:
formats.append("artifact_name")
if self.extra:
formats.extend(iter(self.extra))
return formats
Expand Down
Loading