diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index dc910300..78a2ecae 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -92,19 +92,26 @@ def __init__(self, remote_save_path: Union[str, Path]) -> None: self.code_sandbox_runtime = None - def load(self, file_path: Union[str, Path]) -> None: - """Loads are artifacts into the remote environment. If an artifact value is None - it will skip loading it. + def load( + self, + artifacts_path: Union[str, Path], + load_to: Optional[Union[str, Path]] = None, + ) -> None: + """Loads are artifacts into the load_to path. If load_to is None, it will load + into remote_save_path. If an artifact value is None it will skip loading it. Parameters: - file_path (Union[str, Path]): The file path to load the artifacts from + artifacts_path (Union[str, Path]): The file path to load the artifacts from """ - with open(file_path, "rb") as f: + with open(artifacts_path, "rb") as f: self.artifacts = pkl.load(f) + + load_to = self.remote_save_path.parent if load_to is None else Path(load_to) + for k, v in self.artifacts.items(): if v is not None: mode = "w" if isinstance(v, str) else "wb" - with open(self.remote_save_path.parent / k, mode) as f: + with open(load_to / k, mode) as f: f.write(v) def show(self, uploaded_file_path: Optional[Union[str, Path]] = None) -> str: