From 55fc5982e6af391d07e85bdbaa4df15e18f4f74e Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 15 Oct 2024 21:02:46 -0700 Subject: [PATCH] upload and download artifacts per turn --- vision_agent/agent/vision_agent.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 1bb00621..42204190 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -39,7 +39,7 @@ class BoilerplateCode: "from typing import *", "from vision_agent.utils.execute import CodeInterpreter", "from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning", - "artifacts = Artifacts('{remote_path}')", + "artifacts = Artifacts('{remote_path}', '{remote_path}')", "artifacts.load('{remote_path}')", ] post_code = [ @@ -202,8 +202,10 @@ def _add_media_obs( obs_chat_elt: Message = {"role": "observation", "content": obs} media_obs = check_and_load_image(code_action) if media_obs and result.success: - # for view_media_artifact, we need to ensure the media is loaded - # locally so the conversation agent can actually see it + # for view_media_artifact, we need to ensure the media is loaded locally so + # the conversation agent can actually see it. We also download it here so we + # can check if it contains the actual media (note this is in addition to + # downloading it per turn). code_interpreter.download_file( str(remote_artifacts_path.name), str(local_artifacts_path), @@ -530,6 +532,10 @@ def chat_with_artifacts( ) while not finished and iterations < self.max_iterations: + # ensure we upload the artifacts before each turn, so any local + # modifications we made to it will be reflected in the remote + code_interpreter.upload_file(self.local_artifacts_path) + response = run_conversation(self.agent, int_chat) if self.verbosity >= 1: _LOGGER.info(response) @@ -622,13 +628,14 @@ def chat_with_artifacts( iterations += 1 last_response = response - # after running the agent, download the artifacts locally - code_interpreter.download_file( - str(remote_artifacts_path.name), str(self.local_artifacts_path) - ) - artifacts.load( - self.local_artifacts_path, Path(self.local_artifacts_path).parent - ) + # after each turn, download the artifacts locally + code_interpreter.download_file( + str(remote_artifacts_path.name), str(self.local_artifacts_path) + ) + artifacts.load( + self.local_artifacts_path, Path(self.local_artifacts_path).parent + ) + return orig_chat, artifacts def streaming_message(self, message: Dict[str, Any]) -> None: