Skip to content

Commit 91cce76

Browse files
authored
feat: support artifact name display (#236)
* feat: support artifact name display * fix lint
1 parent b66a8c9 commit 91cce76

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

vision_agent/agent/vision_agent.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from vision_agent.lmm import LMM, Message, OpenAILMM
1616
from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image
1717
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
18+
from vision_agent.tools.tools import extract_frames, save_video
1819
from vision_agent.utils import CodeInterpreterFactory
1920
from vision_agent.utils.execute import CodeInterpreter, Execution
2021

@@ -224,9 +225,20 @@ def chat_with_code(
224225
for media in chat_i["media"]:
225226
if type(media) is str and media.startswith(("http", "https")):
226227
# TODO: Ideally we should not call VA.tools here, we should come to revisit how to better support remote image later
227-
file_path = Path(media).name
228-
ndarray = load_image(media)
229-
save_image(ndarray, file_path)
228+
file_path = str(
229+
Path(self.local_artifacts_path).parent
230+
/ Path(media).name
231+
)
232+
if file_path.lower().endswith(
233+
".mp4"
234+
) or file_path.lower().endswith(".mov"):
235+
video_frames = extract_frames(media)
236+
save_video(
237+
[frame for frame, _ in video_frames], file_path
238+
)
239+
else:
240+
ndarray = load_image(media)
241+
save_image(ndarray, file_path)
230242
media = file_path
231243
else:
232244
media = cast(str, media)

vision_agent/tools/meta_tools.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,25 +53,27 @@ def redisplay_results(execution: Execution) -> None:
5353
"""
5454
for result in execution.results:
5555
if result.text is not None:
56-
display({MimeType.TEXT_PLAIN: result.text})
56+
display({MimeType.TEXT_PLAIN: result.text}, raw=True)
5757
if result.html is not None:
58-
display({MimeType.TEXT_HTML: result.html})
58+
display({MimeType.TEXT_HTML: result.html}, raw=True)
5959
if result.markdown is not None:
60-
display({MimeType.TEXT_MARKDOWN: result.markdown})
60+
display({MimeType.TEXT_MARKDOWN: result.markdown}, raw=True)
6161
if result.svg is not None:
62-
display({MimeType.IMAGE_SVG: result.svg})
62+
display({MimeType.IMAGE_SVG: result.svg}, raw=True)
6363
if result.png is not None:
64-
display({MimeType.IMAGE_PNG: result.png})
64+
display({MimeType.IMAGE_PNG: result.png}, raw=True)
6565
if result.jpeg is not None:
66-
display({MimeType.IMAGE_JPEG: result.jpeg})
66+
display({MimeType.IMAGE_JPEG: result.jpeg}, raw=True)
6767
if result.mp4 is not None:
68-
display({MimeType.VIDEO_MP4_B64: result.mp4})
68+
display({MimeType.VIDEO_MP4_B64: result.mp4}, raw=True)
6969
if result.latex is not None:
70-
display({MimeType.TEXT_LATEX: result.latex})
70+
display({MimeType.TEXT_LATEX: result.latex}, raw=True)
7171
if result.json is not None:
72-
display({MimeType.APPLICATION_JSON: result.json})
72+
display({MimeType.APPLICATION_JSON: result.json}, raw=True)
73+
if result.artifact_name is not None:
74+
display({MimeType.TEXT_ARTIFACT_NAME: result.artifact_name}, raw=True)
7375
if result.extra is not None:
74-
display(result.extra)
76+
display(result.extra, raw=True)
7577

7678

7779
class Artifacts:
@@ -208,7 +210,7 @@ def create_code_artifact(artifacts: Artifacts, name: str) -> str:
208210
return_str = f"[Artifact {name} created]"
209211
print(return_str)
210212

211-
display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
213+
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
212214
return return_str
213215

214216

@@ -292,7 +294,7 @@ def edit_code_artifact(
292294

293295
artifacts[name] = "".join(edited_lines)
294296

295-
display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
297+
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
296298
return open_code_artifact(artifacts, name, cur_line)
297299

298300

@@ -348,7 +350,7 @@ def detect_dogs(image_path: str):
348350
code_lines = code.splitlines(keepends=True)
349351
total_lines = len(code_lines)
350352

351-
display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
353+
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
352354
return view_lines(code_lines, 0, total_lines, name, total_lines)
353355

354356

@@ -413,7 +415,7 @@ def detect_dogs(image_path: str):
413415
code_lines = code.splitlines(keepends=True)
414416
total_lines = len(code_lines)
415417

416-
display({MimeType.APPLICATION_JSON: {"last_artifact": name}})
418+
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
417419
return view_lines(code_lines, 0, total_lines, name, total_lines)
418420

419421

@@ -427,6 +429,7 @@ def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
427429
with open(local_path, "rb") as f:
428430
media = f.read()
429431
artifacts[Path(local_path).name] = media
432+
display({MimeType.TEXT_ARTIFACT_NAME: Path(local_path).name}, raw=True)
430433
return f"[Media {Path(local_path).name} saved]"
431434

432435

@@ -592,6 +595,8 @@ def replacer(match: re.Match) -> str:
592595

593596
diff = get_diff_with_prompts(name, code, new_code)
594597
print(diff)
598+
599+
display({MimeType.TEXT_ARTIFACT_NAME: name}, raw=True)
595600
return diff
596601

597602

vision_agent/utils/execute.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class MimeType(str, Enum):
5656
TEXT_LATEX = "text/latex"
5757
APPLICATION_JSON = "application/json"
5858
APPLICATION_JAVASCRIPT = "application/javascript"
59+
TEXT_ARTIFACT_NAME = "text/artifact/name"
5960

6061

6162
class FileSerializer:
@@ -103,6 +104,7 @@ class Result:
103104
latex: Optional[str] = None
104105
json: Optional[Dict[str, Any]] = None
105106
javascript: Optional[str] = None
107+
artifact_name: Optional[str] = None
106108
extra: Optional[Dict[str, Any]] = None
107109
"Extra data that can be included. Not part of the standard types."
108110

@@ -127,6 +129,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]):
127129
self.latex = data.pop(MimeType.TEXT_LATEX, None)
128130
self.json = data.pop(MimeType.APPLICATION_JSON, None)
129131
self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None)
132+
self.artifact_name = data.pop(MimeType.TEXT_ARTIFACT_NAME, None)
130133
self.extra = data
131134
# Only keeping the PNG representation if both PNG and JPEG are present
132135
if self.png and self.jpeg:
@@ -204,6 +207,8 @@ def formats(self) -> Iterable[str]:
204207
formats.append("javascript")
205208
if self.mp4:
206209
formats.append("mp4")
210+
if self.artifact_name:
211+
formats.append("artifact_name")
207212
if self.extra:
208213
formats.extend(iter(self.extra))
209214
return formats

0 commit comments

Comments
 (0)