From 32b1ce92d6083cebe706e0dc5e393b656b959f3f Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 29 Aug 2024 19:13:08 -0700 Subject: [PATCH] return artifacts --- vision_agent/agent/vision_agent.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index df73cef4..5544b188 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -1,8 +1,9 @@ import copy import logging import os +import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from vision_agent.agent import Agent from vision_agent.agent.agent_utils import extract_json @@ -135,7 +136,7 @@ def __init__( ( Path(local_artifacts_path) if local_artifacts_path is not None - else "artifacts.pkl" + else Path(tempfile.NamedTemporaryFile(delete=False).name) ), ) @@ -161,14 +162,14 @@ def __call__( input = [{"role": "user", "content": input}] if media is not None: input[0]["media"] = [media] - results = self.chat_with_code(input, artifacts) + results, _ = self.chat_with_code(input, artifacts) return results def chat_with_code( self, chat: List[Message], artifacts: Optional[Artifacts] = None, - ) -> List[Message]: + ) -> Tuple[List[Message], Artifacts]: """Chat with VisionAgent, it will use code to execute actions to accomplish its tasks. 
@@ -187,6 +188,7 @@ def chat_with_code( raise ValueError("chat cannot be empty") if not artifacts: + # this sets the remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") with CodeInterpreterFactory.new_instance( @@ -265,9 +267,8 @@ def chat_with_code( if self.verbosity >= 1: _LOGGER.info(obs) - int_chat.append( - {"role": "observation", "content": obs, "execution": result} - ) + # don't add execution results to internal chat + int_chat.append({"role": "observation", "content": obs}) orig_chat.append( {"role": "observation", "content": obs, "execution": result} ) @@ -281,7 +282,7 @@ def chat_with_code( ) artifacts.load(self.local_artifacts_path) artifacts.save() - return orig_chat + return orig_chat, artifacts def log_progress(self, data: Dict[str, Any]) -> None: pass