diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index ef49086a..ba6e1d64 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -21,7 +21,7 @@ use_extra_vision_agent_args, ) from vision_agent.utils import CodeInterpreterFactory -from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution +from vision_agent.utils.execute import CodeInterpreter, Execution logging.basicConfig(level=logging.INFO) _LOGGER = logging.getLogger(__name__) @@ -29,7 +29,6 @@ WORKSPACE.mkdir(parents=True, exist_ok=True) if str(WORKSPACE) != "": os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}" -_SESSION_TIMEOUT = 600 # 10 minutes class BoilerplateCode: @@ -196,7 +195,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent. @@ -207,17 +206,13 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python + code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. """ self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent self.max_iterations = 12 self.verbosity = verbosity - self.code_interpreter = ( - code_interpreter - if code_interpreter is not None - else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT) - ) + self.code_sandbox_runtime = code_sandbox_runtime self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) @@ -289,7 +284,9 @@ def chat_with_code( # this is setting remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") - with self.code_interpreter as code_interpreter: + with CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_sandbox_runtime, + ) as code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) last_user_message = chat[-1] @@ -475,7 +472,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent using OpenAI LMMs. @@ -486,7 +483,7 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python + code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. """ agent = OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent @@ -494,7 +491,7 @@ def __init__( agent, verbosity, local_artifacts_path, - code_interpreter, + code_sandbox_runtime, callback_message, ) @@ -505,7 +502,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent using Anthropic LMMs. @@ -516,7 +513,7 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python + code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. """ agent = AnthropicLMM(temperature=0.0) if agent is None else agent @@ -524,6 +521,6 @@ def __init__( agent, verbosity, local_artifacts_path, - code_interpreter, + code_sandbox_runtime, callback_message, )