Skip to content

Commit

Permalink
Revert "remove e2bcodeinterpreter and codefactory"
Browse files Browse the repository at this point in the history
This reverts commit c67e225.
  • Loading branch information
yzld2002 committed Oct 11, 2024
1 parent c6ef4ad commit 3ee6198
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 14 deletions.
1 change: 1 addition & 0 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
check_and_load_image,
use_extra_vision_agent_args,
)
from vision_agent.utils import CodeInterpreterFactory
from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution

logging.basicConfig(level=logging.INFO)
Expand Down
26 changes: 12 additions & 14 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
OpenAILMM,
)
from vision_agent.tools.meta_tools import get_diff
from vision_agent.utils import Execution
from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter
from vision_agent.utils import CodeInterpreterFactory, Execution
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
from vision_agent.utils.video import play_video
Expand All @@ -49,7 +49,6 @@
_LOGGER = logging.getLogger(__name__)
_MAX_TABULATE_COL_WIDTH = 80
_CONSOLE = Console()
_SESSION_TIMEOUT = 600 # 10 minutes


class DefaultImports:
Expand Down Expand Up @@ -624,7 +623,7 @@ def __init__(
tool_recommender: Optional[Sim] = None,
verbosity: int = 0,
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
code_interpreter: Optional[CodeInterpreter] = None,
code_sandbox_runtime: Optional[str] = None,
) -> None:
"""Initialize the Vision Agent Coder.
Expand All @@ -642,10 +641,11 @@ def __init__(
in a web application where multiple VisionAgentCoder instances are
running in parallel. This callback ensures that the progress is not
mixed up.
code_interpreter (Optional[CodeInterpreter]): the code interpreter to use. A
code interpreter is used to run the generated code. It can be one of the
following values: None, LocalCodeInterpreter or E2BCodeInterpreter.
If None, LocalCodeInterpreter, which is the local python runtime environment will be used.
code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A
code sandbox is used to run the generated code. It can be one of the
following values: None, "local" or "e2b". If None, VisionAgentCoder
will read the value from the environment variable CODE_SANDBOX_RUNTIME.
If it's also None, the local python runtime environment will be used.
"""

self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
Expand All @@ -662,11 +662,7 @@ def __init__(
else tool_recommender
)
self.report_progress_callback = report_progress_callback
self.code_interpreter = (
code_interpreter
if code_interpreter is not None
else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT)
)
self.code_sandbox_runtime = code_sandbox_runtime

def __call__(
self,
Expand Down Expand Up @@ -727,7 +723,9 @@ def chat_with_workflow(
raise ValueError("Chat cannot be empty.")

# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
with self.code_interpreter as code_interpreter:
with CodeInterpreterFactory.new_instance(
code_sandbox_runtime=self.code_sandbox_runtime
) as code_interpreter:
chat = copy.deepcopy(chat)
media_list = []
for chat_i in chat:
Expand Down
1 change: 1 addition & 0 deletions vision_agent/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .execute import (
CodeInterpreter,
CodeInterpreterFactory,
Error,
Execution,
Logs,
Expand Down
203 changes: 203 additions & 0 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@
import re
import sys
import traceback
import warnings
from enum import Enum
from pathlib import Path
from time import sleep
from typing import Any, Dict, Iterable, List, Optional, Union

import nbformat
import tenacity
from dotenv import load_dotenv
from e2b.exceptions import SandboxException
from e2b_code_interpreter import CodeInterpreter as E2BCodeInterpreterImpl
from e2b_code_interpreter import Execution as E2BExecution
from e2b_code_interpreter import Result as E2BResult
from h11._util import LocalProtocolError
from httpx import ConnectError
from httpx import RemoteProtocolError as HttpcoreRemoteProtocolError
from httpx import RemoteProtocolError as HttpxRemoteProtocolError
from nbclient import NotebookClient
from nbclient import __version__ as nbclient_version
from nbclient.exceptions import CellTimeoutError, DeadKernelError
Expand All @@ -23,6 +31,11 @@
from pydantic import BaseModel, field_serializer
from typing_extensions import Self

from vision_agent.utils.exceptions import (
RemoteSandboxCreationError,
RemoteSandboxExecutionError,
)

load_dotenv()
_LOGGER = logging.getLogger(__name__)
_SESSION_TIMEOUT = 600 # 10 minutes
Expand Down Expand Up @@ -421,6 +434,138 @@ def download_file(
return Path(local_file_path)


class E2BCodeInterpreter(CodeInterpreter):
    """Code interpreter that runs cells in a remote E2B sandbox.

    The sandbox is created in ``__init__`` from the template named by the
    ``E2B_TEMPLATE_NAME`` environment variable (default ``"va-sandbox"``) and
    is killed by ``close``.
    """

    # Errors on which ``exec_cell`` and ``upload_file`` are retried.
    # NOTE(review): in this file both RemoteProtocolError aliases are imported
    # from httpx, so they name the same class -- confirm whether the httpcore
    # variant was intended for one of them.
    _RETRYABLE_ERRORS = (
        LocalProtocolError,
        HttpxRemoteProtocolError,
        HttpcoreRemoteProtocolError,
        ConnectError,
        SandboxException,
    )

    def __init__(
        self, remote_path: Optional[Union[str, Path]] = None, *args: Any, **kwargs: Any
    ) -> None:
        """Create the remote sandbox and log the versions it reports.

        Args:
            remote_path: directory inside the sandbox that ``upload_file``
                writes into. Defaults to "/home/user" when None.
            *args: forwarded to the base class initializer and to the E2B
                sandbox constructor.
            **kwargs: forwarded to the base class initializer and to the E2B
                sandbox constructor.

        Raises:
            AssertionError: if the E2B_API_KEY environment variable is unset
                or empty.
            RemoteSandboxCreationError: if the sandbox cannot be created.
        """
        super().__init__(*args, **kwargs)
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; exception type and message are unchanged.
        if not os.getenv("E2B_API_KEY"):
            raise AssertionError("E2B_API_KEY environment variable must be set")
        try:
            self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl(
                *args, **kwargs
            )
        except Exception as e:
            raise RemoteSandboxCreationError(
                f"Failed to create a remote sandbox due to {e}"
            ) from e

        # Run a small probe cell so the log records the Python, OS and
        # vision-agent versions reported by the sandbox.
        result = self.exec_cell(
            """
import platform
import sys
import importlib.metadata
print(f"Python version: {sys.version}")
print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})")
va_version = importlib.metadata.version("vision-agent")
print(f"Vision Agent version: {va_version}")"""
        )
        sys_versions = "\n".join(result.logs.stdout)
        _LOGGER.info(
            f"E2BCodeInterpreter (sandbox id: {self.interpreter.sandbox_id}) initialized:\n{sys_versions}"
        )
        self.remote_path = Path(
            remote_path if remote_path is not None else "/home/user"
        )

    def close(self, *args: Any, **kwargs: Any) -> None:
        """Kill the remote sandbox; failures are logged, never raised.

        The positional and keyword arguments are accepted but not used.
        """
        try:
            self.interpreter.kill(request_timeout=2)
            _LOGGER.info(
                f"The sandbox {self.interpreter.sandbox_id} is closed successfully."
            )
        except Exception as e:
            # Best effort: a failure here is only logged.
            # `Logger.warn` is a deprecated alias, so `warning` is used.
            _LOGGER.warning(
                f"Failed to close the remote sandbox ({self.interpreter.sandbox_id}) due to {e}. This is not an issue. It's likely that the sandbox is already closed due to timeout."
            )

    def restart_kernel(self) -> None:
        """Restart the notebook kernel inside the sandbox."""
        self.interpreter.notebook.restart_kernel()

    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(_RETRYABLE_ERRORS),
        before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO),
        after=tenacity.after_log(_LOGGER, logging.INFO),
    )
    def exec_cell(self, code: str) -> Execution:
        """Execute ``code`` as a notebook cell in the remote sandbox.

        The call is attempted up to 3 times when one of ``_RETRYABLE_ERRORS``
        is raised.

        Args:
            code: source code of the cell to run.

        Returns:
            The E2B execution converted with ``Execution.from_e2b_execution``.

        Raises:
            RemoteSandboxExecutionError: if execution fails with an error that
                is not in ``_RETRYABLE_ERRORS``.
        """
        self.interpreter.set_timeout(_SESSION_TIMEOUT)  # Extend the life of the sandbox
        try:
            _LOGGER.info(
                f"Start code execution in remote sandbox {self.interpreter.sandbox_id}. Timeout: {_SESSION_TIMEOUT}. Code hash: {hash(code)}"
            )
            execution = self.interpreter.notebook.exec_cell(
                code=code,
                on_stdout=lambda msg: _LOGGER.info(msg),
                on_stderr=lambda msg: _LOGGER.info(msg),
            )
            _LOGGER.info(
                f"Finished code execution in remote sandbox {self.interpreter.sandbox_id}. Code hash: {hash(code)}"
            )
            return Execution.from_e2b_execution(execution)
        except self._RETRYABLE_ERRORS:
            # Let these propagate untouched so the retry decorator sees them.
            raise
        except Exception as e:
            raise RemoteSandboxExecutionError(
                f"Failed executing code in remote sandbox ({self.interpreter.sandbox_id}) due to error '{type(e).__name__} {str(e)}', code: {code}"
            ) from e

    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(_RETRYABLE_ERRORS),
        before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO),
        after=tenacity.after_log(_LOGGER, logging.INFO),
    )
    def upload_file(self, file: Union[str, Path]) -> Path:
        """Upload a local file into ``self.remote_path`` in the sandbox.

        The call is attempted up to 3 times when one of ``_RETRYABLE_ERRORS``
        is raised.

        Args:
            file: path of the local file to upload.

        Returns:
            The remote path of the uploaded file (``remote_path / file name``).
        """
        file_name = Path(file).name
        with open(file, "rb") as f:
            self.interpreter.files.write(path=str(self.remote_path / file_name), data=f)
        _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}")
        return self.remote_path / file_name

    def download_file(
        self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
    ) -> Path:
        """Copy a file from the sandbox to the local filesystem.

        Args:
            remote_file_path: path of the file inside the sandbox.
            local_file_path: local destination; it is created or truncated.

        Returns:
            ``local_file_path`` as a ``Path``.
        """
        with open(local_file_path, "w+b") as f:
            f.write(
                self.interpreter.files.read(path=str(remote_file_path), format="bytes")
            )
        _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}")
        return Path(local_file_path)

    @staticmethod
    def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl:  # type: ignore
        """Build the underlying E2B interpreter from the configured template.

        The template name is read from the ``E2B_TEMPLATE_NAME`` environment
        variable and falls back to "va-sandbox".
        """
        template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox")
        _LOGGER.info(
            f"Creating a new E2BCodeInterpreter using template: {template_name}"
        )
        # Positional arguments go first; the call is otherwise unchanged.
        return E2BCodeInterpreterImpl(*args, template=template_name, **kwargs)


class LocalCodeInterpreter(CodeInterpreter):
def __init__(
self,
Expand Down Expand Up @@ -522,6 +667,64 @@ def download_file(
return Path(local_file_path)


class CodeInterpreterFactory:
    """Creates code interpreters for the supported sandbox runtimes.

    Two runtimes are supported: "local" and "e2b".
    """

    _instance_map: Dict[str, CodeInterpreter] = {}
    _default_key = "default"

    @staticmethod
    def get_default_instance() -> CodeInterpreter:
        """Return the cached default interpreter, creating it on first use.

        A warning is emitted on every call pointing callers to
        ``new_instance()``.
        """
        warnings.warn(
            "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future."
        )
        cache = CodeInterpreterFactory._instance_map
        key = CodeInterpreterFactory._default_key
        cached = cache.get(key)
        if not cached:
            cached = CodeInterpreterFactory.new_instance()
            cache[key] = cached
        return cached

    @staticmethod
    def new_instance(
        code_sandbox_runtime: Optional[str] = None,
        remote_path: Optional[Union[str, Path]] = None,
    ) -> CodeInterpreter:
        """Create a fresh interpreter for the requested runtime.

        Args:
            code_sandbox_runtime: "local" or "e2b". When falsy, the value of
                the CODE_SANDBOX_RUNTIME environment variable is used, and
                "local" if that variable is not set.
            remote_path: passed through to the interpreter constructor.

        Returns:
            A new ``LocalCodeInterpreter`` or ``E2BCodeInterpreter``.

        Raises:
            ValueError: if the resolved runtime is neither "e2b" nor "local".
        """
        runtime = code_sandbox_runtime or os.getenv("CODE_SANDBOX_RUNTIME", "local")
        if runtime == "local":
            return LocalCodeInterpreter(
                timeout=_SESSION_TIMEOUT, remote_path=remote_path
            )
        if runtime == "e2b":
            return E2BCodeInterpreter(
                timeout=_SESSION_TIMEOUT,
                remote_path=remote_path,
                envs=_get_e2b_env(),
            )
        raise ValueError(
            f"Unsupported code sandbox runtime: {runtime}. Supported runtimes: e2b, local"
        )


def _get_e2b_env() -> Union[Dict[str, str], None]:
openai_api_key = os.getenv("OPENAI_API_KEY", "")
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "")
if openai_api_key or anthropic_api_key:
envs = {"WORKSPACE": os.getenv("WORKSPACE", "/home/user")}
if openai_api_key:
envs["OPENAI_API_KEY"] = openai_api_key
if anthropic_api_key:
envs["ANTHROPIC_API_KEY"] = anthropic_api_key
else:
envs = None
return envs


def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution:
"""Parse notebook cell outputs to Execution object. Output types:
https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs
Expand Down

0 comments on commit 3ee6198

Please sign in to comment.