diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 98e70fa8..ef49086a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -20,6 +20,7 @@ check_and_load_image, use_extra_vision_agent_args, ) +from vision_agent.utils import CodeInterpreterFactory from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution logging.basicConfig(level=logging.INFO) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index e8dc2247..1e5030a2 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -38,8 +38,8 @@ OpenAILMM, ) from vision_agent.tools.meta_tools import get_diff -from vision_agent.utils import Execution -from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter +from vision_agent.utils import CodeInterpreterFactory, Execution +from vision_agent.utils.execute import CodeInterpreter from vision_agent.utils.image_utils import b64_to_pil from vision_agent.utils.sim import AzureSim, OllamaSim, Sim from vision_agent.utils.video import play_video @@ -49,7 +49,6 @@ _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 _CONSOLE = Console() -_SESSION_TIMEOUT = 600 # 10 minutes class DefaultImports: @@ -624,7 +623,7 @@ def __init__( tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, ) -> None: """Initialize the Vision Agent Coder. @@ -642,10 +641,11 @@ def __init__( in a web application where multiple VisionAgentCoder instances are running in parallel. This callback ensures that the progress are not mixed up. - code_interpreter (Optional[CodeInterpreter]): the code interpreter to use. A - code interpreter is used to run the generated code. 
It can be one of the - following values: None, LocalCodeInterpreter or E2BCodeInterpreter. - If None, LocalCodeInterpreter, which is the local python runtime environment will be used. + code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A + code sandbox is used to run the generated code. It can be one of the + following values: None, "local" or "e2b". If None, VisionAgentCoder + will read the value from the environment variable CODE_SANDBOX_RUNTIME. + If it's also None, the local python runtime environment will be used. """ self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner @@ -662,11 +662,7 @@ def __init__( else tool_recommender ) self.report_progress_callback = report_progress_callback - self.code_interpreter = ( - code_interpreter - if code_interpreter is not None - else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT) - ) + self.code_sandbox_runtime = code_sandbox_runtime def __call__( self, @@ -727,7 +723,9 @@ def chat_with_workflow( raise ValueError("Chat cannot be empty.") # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues - with self.code_interpreter as code_interpreter: + with CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_sandbox_runtime + ) as code_interpreter: chat = copy.deepcopy(chat) media_list = [] for chat_i in chat: diff --git a/vision_agent/utils/__init__.py b/vision_agent/utils/__init__.py index 819477c1..2810713a 100644 --- a/vision_agent/utils/__init__.py +++ b/vision_agent/utils/__init__.py @@ -1,5 +1,6 @@ from .execute import ( CodeInterpreter, + CodeInterpreterFactory, Error, Execution, Logs, diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 38c38150..1907d106 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -6,15 +6,23 @@ import re import sys import traceback +import warnings from enum import Enum from pathlib import Path from time import sleep from typing import 
# Transient transport / sandbox failures. `exec_cell` and `upload_file` retry
# on these; `exec_cell` wraps any other exception in RemoteSandboxExecutionError.
_E2B_RETRYABLE_ERRORS = (
    LocalProtocolError,
    HttpxRemoteProtocolError,
    HttpcoreRemoteProtocolError,
    ConnectError,
    SandboxException,
)


class E2BCodeInterpreter(CodeInterpreter):
    """Code interpreter that runs cells inside a remote E2B sandbox.

    The sandbox is created from the template named by the ``E2B_TEMPLATE_NAME``
    environment variable (default ``"va-sandbox"``). ``E2B_API_KEY`` must be
    set in the environment.
    """

    def __init__(
        self, remote_path: Optional[Union[str, Path]] = None, *args: Any, **kwargs: Any
    ) -> None:
        """Create the remote sandbox and log its Python/OS/vision-agent versions.

        Args:
            remote_path: directory inside the sandbox used as the upload
                target. Defaults to ``/home/user`` when None.
            *args: forwarded to both ``CodeInterpreter.__init__`` and the
                underlying E2B interpreter constructor.
            **kwargs: forwarded to both ``CodeInterpreter.__init__`` and the
                underlying E2B interpreter constructor.

        Raises:
            AssertionError: if ``E2B_API_KEY`` is not set.
            RemoteSandboxCreationError: if the sandbox cannot be created.
        """
        super().__init__(*args, **kwargs)
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if not os.getenv("E2B_API_KEY"):
            raise AssertionError("E2B_API_KEY environment variable must be set")
        try:
            self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl(
                *args, **kwargs
            )
        except Exception as e:
            raise RemoteSandboxCreationError(
                f"Failed to create a remote sandbox due to {e}"
            ) from e

        # Run a small probe cell so the sandbox environment shows up in the logs.
        result = self.exec_cell(
            """
import platform
import sys
import importlib.metadata

print(f"Python version: {sys.version}")
print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})")
va_version = importlib.metadata.version("vision-agent")
print(f"Vision Agent version: {va_version}")"""
        )
        sys_versions = "\n".join(result.logs.stdout)
        _LOGGER.info(
            f"E2BCodeInterpreter (sandbox id: {self.interpreter.sandbox_id}) initialized:\n{sys_versions}"
        )
        self.remote_path = Path(
            remote_path if remote_path is not None else "/home/user"
        )

    def close(self, *args: Any, **kwargs: Any) -> None:
        """Kill the remote sandbox; failures are logged and never raised."""
        try:
            self.interpreter.kill(request_timeout=2)
            _LOGGER.info(
                f"The sandbox {self.interpreter.sandbox_id} is closed successfully."
            )
        except Exception as e:
            # Best effort by design: the sandbox may already be gone.
            _LOGGER.warning(
                f"Failed to close the remote sandbox ({self.interpreter.sandbox_id}) due to {e}. This is not an issue. It's likely that the sandbox is already closed due to timeout."
            )

    def restart_kernel(self) -> None:
        """Restart the notebook kernel inside the sandbox."""
        self.interpreter.notebook.restart_kernel()

    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(_E2B_RETRYABLE_ERRORS),
        before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO),
        after=tenacity.after_log(_LOGGER, logging.INFO),
    )
    def exec_cell(self, code: str) -> Execution:
        """Execute ``code`` in the sandbox notebook and return the result.

        Transient transport/sandbox errors propagate unchanged so the retry
        decorator can see them (up to 3 attempts).

        Args:
            code: source of the cell to run.

        Returns:
            The execution converted with ``Execution.from_e2b_execution``.

        Raises:
            RemoteSandboxExecutionError: on any non-retryable failure.
        """
        self.interpreter.set_timeout(_SESSION_TIMEOUT)  # Extend the life of the sandbox
        try:
            _LOGGER.info(
                f"Start code execution in remote sandbox {self.interpreter.sandbox_id}. Timeout: {_SESSION_TIMEOUT}. Code hash: {hash(code)}"
            )
            execution = self.interpreter.notebook.exec_cell(
                code=code,
                on_stdout=lambda msg: _LOGGER.info(msg),
                on_stderr=lambda msg: _LOGGER.info(msg),
            )
            _LOGGER.info(
                f"Finished code execution in remote sandbox {self.interpreter.sandbox_id}. Code hash: {hash(code)}"
            )
            return Execution.from_e2b_execution(execution)
        except _E2B_RETRYABLE_ERRORS:
            # Re-raise untouched so tenacity can match the type and retry.
            raise
        except Exception as e:
            raise RemoteSandboxExecutionError(
                f"Failed executing code in remote sandbox ({self.interpreter.sandbox_id}) due to error '{type(e).__name__} {str(e)}', code: {code}"
            ) from e

    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(_E2B_RETRYABLE_ERRORS),
        before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO),
        after=tenacity.after_log(_LOGGER, logging.INFO),
    )
    def upload_file(self, file: Union[str, Path]) -> Path:
        """Upload a local file into ``self.remote_path`` under its base name.

        Args:
            file: path of the local file to upload.

        Returns:
            The path of the uploaded file inside the sandbox.
        """
        remote_file = self.remote_path / Path(file).name
        with open(file, "rb") as f:
            self.interpreter.files.write(path=str(remote_file), data=f)
        _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}")
        return remote_file

    def download_file(
        self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
    ) -> Path:
        """Copy a file from the sandbox to ``local_file_path``.

        Args:
            remote_file_path: path of the file inside the sandbox.
            local_file_path: destination on the local filesystem; it is
                created or truncated.

        Returns:
            ``local_file_path`` as a ``Path``.
        """
        with open(local_file_path, "w+b") as f:
            f.write(
                self.interpreter.files.read(path=str(remote_file_path), format="bytes")
            )
        _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}")
        return Path(local_file_path)

    @staticmethod
    def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl:  # type: ignore
        """Build the underlying E2B interpreter from the configured template."""
        template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox")
        _LOGGER.info(
            f"Creating a new E2BCodeInterpreter using template: {template_name}"
        )
        # Positional star-args go before the keyword; the call is equivalent.
        return E2BCodeInterpreterImpl(*args, template=template_name, **kwargs)
class CodeInterpreterFactory:
    """Factory class for creating code interpreters.
    Could be extended to support multiple code interpreters.
    """

    # Cache used only by get_default_instance(); new_instance() never reads it.
    _instance_map: Dict[str, "CodeInterpreter"] = {}
    _default_key = "default"

    @staticmethod
    def get_default_instance() -> "CodeInterpreter":
        """Return the cached default interpreter, creating it on first use.

        Emits a warning on every call: this accessor is meant for testing and
        is slated for removal; use ``new_instance()`` instead.
        """
        warnings.warn(
            "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future.",
            stacklevel=2,  # attribute the warning to the caller, not this factory
        )
        inst_map = CodeInterpreterFactory._instance_map
        instance = inst_map.get(CodeInterpreterFactory._default_key)
        if instance is None:
            instance = CodeInterpreterFactory.new_instance()
            inst_map[CodeInterpreterFactory._default_key] = instance
        return instance

    @staticmethod
    def new_instance(
        code_sandbox_runtime: Optional[str] = None,
        remote_path: Optional[Union[str, Path]] = None,
    ) -> "CodeInterpreter":
        """Create a fresh code interpreter for the requested runtime.

        Args:
            code_sandbox_runtime: ``"e2b"`` or ``"local"``, matched ignoring
                case and surrounding whitespace. When None or empty, the
                ``CODE_SANDBOX_RUNTIME`` environment variable is used; when
                that is unset or empty too, ``"local"`` is used.
            remote_path: forwarded to the interpreter constructor.

        Returns:
            A new ``E2BCodeInterpreter`` or ``LocalCodeInterpreter``.

        Raises:
            ValueError: if the runtime is neither ``"e2b"`` nor ``"local"``.
        """
        # The `or` chain also covers an env var that is set but empty.
        requested = (
            code_sandbox_runtime or os.getenv("CODE_SANDBOX_RUNTIME") or "local"
        )
        runtime = requested.strip().lower()
        if runtime == "e2b":
            return E2BCodeInterpreter(
                timeout=_SESSION_TIMEOUT, remote_path=remote_path, envs=_get_e2b_env()
            )
        if runtime == "local":
            return LocalCodeInterpreter(
                timeout=_SESSION_TIMEOUT, remote_path=remote_path
            )
        raise ValueError(
            f"Unsupported code sandbox runtime: {requested}. Supported runtimes: e2b, local"
        )


def _get_e2b_env() -> Optional[Dict[str, str]]:
    """Collect the environment variables to pass to a new E2B sandbox.

    Returns:
        None when neither ``OPENAI_API_KEY`` nor ``ANTHROPIC_API_KEY`` is set
        to a non-empty value. Otherwise a dict holding ``WORKSPACE`` (default
        ``/home/user``) plus whichever of the two API keys are set.
    """
    api_keys = {
        name: os.environ[name]
        for name in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY")
        if os.environ.get(name)
    }
    if not api_keys:
        return None
    return {"WORKSPACE": os.getenv("WORKSPACE", "/home/user"), **api_keys}
Output types: https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs