From 02ee6e42fc32ee8f61dc301a93ecec4bdc689810 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 09:05:55 +0800 Subject: [PATCH] add an optional code_interpreter --- vision_agent/agent/vision_agent.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index ba6e1d64..1a38468f 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -197,6 +197,7 @@ def __init__( local_artifacts_path: Optional[Union[str, Path]] = None, code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[CodeInterpreter] = None, ) -> None: """Initialize the VisionAgent. @@ -207,12 +208,14 @@ def __init__( local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + code_interpreter (Optional[CodeInterpreter]): if not None, use this CodeInterpreter """ self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent self.max_iterations = 12 self.verbosity = verbosity self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = code_interpreter self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) @@ -284,9 +287,14 @@ def chat_with_code( # this is setting remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") - with CodeInterpreterFactory.new_instance( - code_sandbox_runtime=self.code_sandbox_runtime, - ) as code_interpreter: + code_interpreter = ( + self.code_interpreter + if self.code_interpreter is not None + else CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_sandbox_runtime, + ) + ) + with code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) last_user_message = chat[-1]