Skip to content

Commit

Permalink
ensure artifact is saved
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 16, 2024
1 parent 666ab3c commit 906ee66
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
12 changes: 11 additions & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def execute_code_action(
obs = str(result.logs)
if result.error:
obs += f"\n{result.error}"
__import__("ipdb").set_trace()
extract_and_save_files_to_artifacts(artifacts, code, obs, result)
return result, obs

Expand Down Expand Up @@ -323,6 +324,7 @@ def __init__(
agent: Optional[LMM] = None,
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
remote_artifacts_path: Optional[Union[str, Path]] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[Union[str, CodeInterpreter]] = None,
) -> None:
Expand Down Expand Up @@ -357,6 +359,14 @@ def __init__(
else Path(tempfile.NamedTemporaryFile(delete=False).name)
),
)
self.remote_artifacts_path = cast(
str,
(
Path(remote_artifacts_path)
if remote_artifacts_path is not None
else Path(WORKSPACE / "artifacts.pkl")
),
)

def __call__(
self,
Expand Down Expand Up @@ -433,7 +443,7 @@ def chat_with_artifacts(

if not artifacts:
# this is setting remote artifacts path
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")
artifacts = Artifacts(self.remote_artifacts_path, self.local_artifacts_path)

# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
code_interpreter = (
Expand Down
10 changes: 6 additions & 4 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ class Artifacts:
need to be in sync with the remote environment the VisionAgent is running in.
"""

def __init__(self, remote_save_path: Union[str, Path]) -> None:
def __init__(
self, remote_save_path: Union[str, Path], local_save_path: Union[str, Path]
) -> None:
self.remote_save_path = Path(remote_save_path)
self.local_save_path = Path(local_save_path)
self.artifacts: Dict[str, Any] = {}

self.code_sandbox_runtime = None
Expand Down Expand Up @@ -132,9 +135,7 @@ def show(self, uploaded_file_path: Optional[Union[str, Path]] = None) -> str:
return output_str

def save(self, local_path: Optional[Union[str, Path]] = None) -> None:
save_path = (
Path(local_path) if local_path is not None else self.remote_save_path
)
save_path = Path(local_path) if local_path is not None else self.local_save_path
with open(save_path, "wb") as f:
pkl.dump(self.artifacts, f)

Expand Down Expand Up @@ -876,6 +877,7 @@ def extract_and_save_files_to_artifacts(
list(artifacts.artifacts.keys()),
)
artifacts[new_name] = files[format][j]
artifacts.save()


META_TOOL_DOCSTRING = get_tool_documentation(
Expand Down

0 comments on commit 906ee66

Please sign in to comment.