Skip to content

Commit

Permalink
return artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 30, 2024
1 parent ac9a5e0 commit 32b1ce9
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_json
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
(
Path(local_artifacts_path)
if local_artifacts_path is not None
else "artifacts.pkl"
else Path(tempfile.NamedTemporaryFile(delete=False).name)
),
)

Expand All @@ -161,14 +162,14 @@ def __call__(
input = [{"role": "user", "content": input}]
if media is not None:
input[0]["media"] = [media]
results = self.chat_with_code(input, artifacts)
results, _ = self.chat_with_code(input, artifacts)
return results

def chat_with_code(
self,
chat: List[Message],
artifacts: Optional[Artifacts] = None,
) -> List[Message]:
) -> Tuple[List[Message], Artifacts]:
"""Chat with VisionAgent, it will use code to execute actions to accomplish
its tasks.
Expand All @@ -187,6 +188,7 @@ def chat_with_code(
raise ValueError("chat cannot be empty")

if not artifacts:
# this is setting remote artifacts path
artifacts = Artifacts(WORKSPACE / "artifacts.pkl")

with CodeInterpreterFactory.new_instance(
Expand Down Expand Up @@ -265,9 +267,8 @@ def chat_with_code(

if self.verbosity >= 1:
_LOGGER.info(obs)
int_chat.append(
{"role": "observation", "content": obs, "execution": result}
)
# don't add execution results to internal chat
int_chat.append({"role": "observation", "content": obs})
orig_chat.append(
{"role": "observation", "content": obs, "execution": result}
)
Expand All @@ -281,7 +282,7 @@ def chat_with_code(
)
artifacts.load(self.local_artifacts_path)
artifacts.save()
return orig_chat
return orig_chat, artifacts

def log_progress(self, data: Dict[str, Any]) -> None:
pass

0 comments on commit 32b1ce9

Please sign in to comment.