From 906ee6684e2561a205f5e85c438441e6880c9bbb Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 15 Oct 2024 20:24:41 -0700 Subject: [PATCH] ensure artifact is saved --- vision_agent/agent/vision_agent.py | 12 +++++++++++- vision_agent/tools/meta_tools.py | 10 ++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 64a8ff49..51745f53 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -155,6 +155,7 @@ def execute_code_action( obs = str(result.logs) if result.error: obs += f"\n{result.error}" + __import__("ipdb").set_trace() extract_and_save_files_to_artifacts(artifacts, code, obs, result) return result, obs @@ -323,6 +324,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, + remote_artifacts_path: Optional[Union[str, Path]] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, code_interpreter: Optional[Union[str, CodeInterpreter]] = None, ) -> None: @@ -357,6 +359,14 @@ def __init__( else Path(tempfile.NamedTemporaryFile(delete=False).name) ), ) + self.remote_artifacts_path = cast( + str, + ( + Path(remote_artifacts_path) + if remote_artifacts_path is not None + else Path(WORKSPACE / "artifacts.pkl") + ), + ) def __call__( self, @@ -433,7 +443,7 @@ def chat_with_artifacts( if not artifacts: # this is setting remote artifacts path - artifacts = Artifacts(WORKSPACE / "artifacts.pkl") + artifacts = Artifacts(self.remote_artifacts_path, self.local_artifacts_path) # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues code_interpreter = ( diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 10c44bac..d13e3731 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -87,8 +87,11 @@ class Artifacts: need to be in sync with the remote environment the VisionAgent is running in. """ - def __init__(self, remote_save_path: Union[str, Path]) -> None: + def __init__( + self, remote_save_path: Union[str, Path], local_save_path: Union[str, Path] + ) -> None: self.remote_save_path = Path(remote_save_path) + self.local_save_path = Path(local_save_path) self.artifacts: Dict[str, Any] = {} self.code_sandbox_runtime = None @@ -132,9 +135,7 @@ def show(self, uploaded_file_path: Optional[Union[str, Path]] = None) -> str: return output_str def save(self, local_path: Optional[Union[str, Path]] = None) -> None: - save_path = ( - Path(local_path) if local_path is not None else self.remote_save_path - ) + save_path = Path(local_path) if local_path is not None else self.local_save_path with open(save_path, "wb") as f: pkl.dump(self.artifacts, f) @@ -876,6 +877,7 @@ def extract_and_save_files_to_artifacts( list(artifacts.artifacts.keys()), ) artifacts[new_name] = files[format][j] + artifacts.save() META_TOOL_DOCSTRING = get_tool_documentation(