Skip to content

Commit

Permalink
upload and download artifacts per turn
Browse files (browse the repository at this point in the history)
  • Loading branch information
dillonalaird committed Oct 16, 2024
1 parent b58e48d commit 55fc598
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions vision_agent/agent/vision_agent.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,7 +39,7 @@ class BoilerplateCode:
"from typing import *",
"from vision_agent.utils.execute import CodeInterpreter",
"from vision_agent.tools.meta_tools import Artifacts, open_code_artifact, create_code_artifact, edit_code_artifact, get_tool_descriptions, generate_vision_code, edit_vision_code, view_media_artifact, object_detection_fine_tuning, use_object_detection_fine_tuning",
"artifacts = Artifacts('{remote_path}')",
"artifacts = Artifacts('{remote_path}', '{remote_path}')",
"artifacts.load('{remote_path}')",
]
post_code = [
Expand Down Expand Up @@ -202,8 +202,10 @@ def _add_media_obs(
obs_chat_elt: Message = {"role": "observation", "content": obs}
media_obs = check_and_load_image(code_action)
if media_obs and result.success:
# for view_media_artifact, we need to ensure the media is loaded
# locally so the conversation agent can actually see it
# for view_media_artifact, we need to ensure the media is loaded locally so
# the conversation agent can actually see it. We also download it here so we
# can check if it contains the actual media (note this is in addition to
# downloading it per turn).
code_interpreter.download_file(
str(remote_artifacts_path.name),
str(local_artifacts_path),
Expand Down Expand Up @@ -530,6 +532,10 @@ def chat_with_artifacts(
)

while not finished and iterations < self.max_iterations:
# ensure we upload the artifacts before each turn, so any local
# modifications we made to them will be reflected in the remote copy
code_interpreter.upload_file(self.local_artifacts_path)

response = run_conversation(self.agent, int_chat)
if self.verbosity >= 1:
_LOGGER.info(response)
Expand Down Expand Up @@ -622,13 +628,14 @@ def chat_with_artifacts(
iterations += 1
last_response = response

# after running the agent, download the artifacts locally
code_interpreter.download_file(
str(remote_artifacts_path.name), str(self.local_artifacts_path)
)
artifacts.load(
self.local_artifacts_path, Path(self.local_artifacts_path).parent
)
# after each turn, download the artifacts locally
code_interpreter.download_file(
str(remote_artifacts_path.name), str(self.local_artifacts_path)
)
artifacts.load(
self.local_artifacts_path, Path(self.local_artifacts_path).parent
)

return orig_chat, artifacts

def streaming_message(self, message: Dict[str, Any]) -> None:
Expand Down

0 comments on commit 55fc598

Please sign in to comment.