Skip to content

Commit 3ee6198

Browse files
committed
Revert "remove e2bcodeinterpreter and codefactory"
This reverts commit c67e225.
1 parent c6ef4ad commit 3ee6198

File tree

4 files changed

+217
-14
lines changed

4 files changed

+217
-14
lines changed

vision_agent/agent/vision_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
check_and_load_image,
2121
use_extra_vision_agent_args,
2222
)
23+
from vision_agent.utils import CodeInterpreterFactory
2324
from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution
2425

2526
logging.basicConfig(level=logging.INFO)

vision_agent/agent/vision_agent_coder.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
OpenAILMM,
3939
)
4040
from vision_agent.tools.meta_tools import get_diff
41-
from vision_agent.utils import Execution
42-
from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter
41+
from vision_agent.utils import CodeInterpreterFactory, Execution
42+
from vision_agent.utils.execute import CodeInterpreter
4343
from vision_agent.utils.image_utils import b64_to_pil
4444
from vision_agent.utils.sim import AzureSim, OllamaSim, Sim
4545
from vision_agent.utils.video import play_video
@@ -49,7 +49,6 @@
4949
_LOGGER = logging.getLogger(__name__)
5050
_MAX_TABULATE_COL_WIDTH = 80
5151
_CONSOLE = Console()
52-
_SESSION_TIMEOUT = 600 # 10 minutes
5352

5453

5554
class DefaultImports:
@@ -624,7 +623,7 @@ def __init__(
624623
tool_recommender: Optional[Sim] = None,
625624
verbosity: int = 0,
626625
report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
627-
code_interpreter: Optional[CodeInterpreter] = None,
626+
code_sandbox_runtime: Optional[str] = None,
628627
) -> None:
629628
"""Initialize the Vision Agent Coder.
630629
@@ -642,10 +641,11 @@ def __init__(
642641
in a web application where multiple VisionAgentCoder instances are
643642
running in parallel. This callback ensures that the progress are not
644643
mixed up.
645-
code_interpreter (Optional[CodeInterpreter]): the code interpreter to use. A
646-
code interpreter is used to run the generated code. It can be one of the
647-
following values: None, LocalCodeInterpreter or E2BCodeInterpreter.
648-
If None, LocalCodeInterpreter, which is the local python runtime environment will be used.
644+
code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A
645+
code sandbox is used to run the generated code. It can be one of the
646+
following values: None, "local" or "e2b". If None, VisionAgentCoder
647+
will read the value from the environment variable CODE_SANDBOX_RUNTIME.
648+
If it's also None, the local python runtime environment will be used.
649649
"""
650650

651651
self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner
@@ -662,11 +662,7 @@ def __init__(
662662
else tool_recommender
663663
)
664664
self.report_progress_callback = report_progress_callback
665-
self.code_interpreter = (
666-
code_interpreter
667-
if code_interpreter is not None
668-
else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT)
669-
)
665+
self.code_sandbox_runtime = code_sandbox_runtime
670666

671667
def __call__(
672668
self,
@@ -727,7 +723,9 @@ def chat_with_workflow(
727723
raise ValueError("Chat cannot be empty.")
728724

729725
# NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues
730-
with self.code_interpreter as code_interpreter:
726+
with CodeInterpreterFactory.new_instance(
727+
code_sandbox_runtime=self.code_sandbox_runtime
728+
) as code_interpreter:
731729
chat = copy.deepcopy(chat)
732730
media_list = []
733731
for chat_i in chat:

vision_agent/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .execute import (
22
CodeInterpreter,
3+
CodeInterpreterFactory,
34
Error,
45
Execution,
56
Logs,

vision_agent/utils/execute.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,23 @@
66
import re
77
import sys
88
import traceback
9+
import warnings
910
from enum import Enum
1011
from pathlib import Path
1112
from time import sleep
1213
from typing import Any, Dict, Iterable, List, Optional, Union
1314

1415
import nbformat
16+
import tenacity
1517
from dotenv import load_dotenv
18+
from e2b.exceptions import SandboxException
19+
from e2b_code_interpreter import CodeInterpreter as E2BCodeInterpreterImpl
1620
from e2b_code_interpreter import Execution as E2BExecution
1721
from e2b_code_interpreter import Result as E2BResult
22+
from h11._util import LocalProtocolError
23+
from httpx import ConnectError
24+
from httpx import RemoteProtocolError as HttpcoreRemoteProtocolError
25+
from httpx import RemoteProtocolError as HttpxRemoteProtocolError
1826
from nbclient import NotebookClient
1927
from nbclient import __version__ as nbclient_version
2028
from nbclient.exceptions import CellTimeoutError, DeadKernelError
@@ -23,6 +31,11 @@
2331
from pydantic import BaseModel, field_serializer
2432
from typing_extensions import Self
2533

34+
from vision_agent.utils.exceptions import (
35+
RemoteSandboxCreationError,
36+
RemoteSandboxExecutionError,
37+
)
38+
2639
load_dotenv()
2740
_LOGGER = logging.getLogger(__name__)
2841
_SESSION_TIMEOUT = 600 # 10 minutes
@@ -421,6 +434,138 @@ def download_file(
421434
return Path(local_file_path)
422435

423436

437+
class E2BCodeInterpreter(CodeInterpreter):
    """Code interpreter backed by a remote E2B sandbox.

    The sandbox is created in ``__init__`` from the template named by the
    ``E2B_TEMPLATE_NAME`` environment variable (default ``"va-sandbox"``).
    Code execution and file uploads are retried up to three times on
    transient transport/sandbox errors.
    """

    # Transient failures worth retrying. Shared by the retry decorators and
    # by exec_cell, which must let these propagate untouched so tenacity
    # can see them.
    # NOTE(review): both RemoteProtocolError aliases are imported from httpx
    # at the top of this file, so they currently name the same class --
    # confirm whether the "Httpcore" alias was meant to come from httpcore.
    _RETRYABLE_ERRORS = (
        LocalProtocolError,
        HttpxRemoteProtocolError,
        HttpcoreRemoteProtocolError,
        ConnectError,
        SandboxException,
    )

    def __init__(
        self, remote_path: Optional[Union[str, Path]] = None, *args: Any, **kwargs: Any
    ) -> None:
        """Create the remote sandbox and log its runtime versions.

        Parameters:
            remote_path: directory inside the sandbox that uploaded files are
                written to. Defaults to "/home/user" when None.
            *args, **kwargs: forwarded both to the base class initializer and
                to the underlying E2B interpreter constructor.

        Raises:
            AssertionError: if the E2B_API_KEY environment variable is unset
                or empty.
            RemoteSandboxCreationError: if the sandbox cannot be created.
        """
        super().__init__(*args, **kwargs)
        if not os.getenv("E2B_API_KEY"):
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`; AssertionError is kept so existing
            # callers observe the same exception type and message.
            raise AssertionError("E2B_API_KEY environment variable must be set")
        try:
            self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl(
                *args, **kwargs
            )
        except Exception as e:
            raise RemoteSandboxCreationError(
                f"Failed to create a remote sandbox due to {e}"
            ) from e

        # Run a small probe cell so the sandbox's environment shows up in logs.
        result = self.exec_cell(
            """
import platform
import sys
import importlib.metadata

print(f"Python version: {sys.version}")
print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})")
va_version = importlib.metadata.version("vision-agent")
print(f"Vision Agent version: {va_version}")"""
        )
        sys_versions = "\n".join(result.logs.stdout)
        _LOGGER.info(
            f"E2BCodeInterpreter (sandbox id: {self.interpreter.sandbox_id}) initialized:\n{sys_versions}"
        )
        self.remote_path = Path(
            remote_path if remote_path is not None else "/home/user"
        )

    def close(self, *args: Any, **kwargs: Any) -> None:
        """Kill the remote sandbox; failures are logged, never raised."""
        try:
            self.interpreter.kill(request_timeout=2)
            _LOGGER.info(
                f"The sandbox {self.interpreter.sandbox_id} is closed successfully."
            )
        except Exception as e:
            # Best effort by design: a failure here is only logged.
            # `warning` replaces the deprecated `warn` alias.
            _LOGGER.warning(
                f"Failed to close the remote sandbox ({self.interpreter.sandbox_id}) due to {e}. This is not an issue. It's likely that the sandbox is already closed due to timeout."
            )

    def restart_kernel(self) -> None:
        """Restart the notebook kernel inside the sandbox."""
        self.interpreter.notebook.restart_kernel()

    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(_RETRYABLE_ERRORS),
        before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO),
        after=tenacity.after_log(_LOGGER, logging.INFO),
    )
    def exec_cell(self, code: str) -> Execution:
        """Execute ``code`` as a notebook cell in the remote sandbox.

        Parameters:
            code: source code to run.

        Returns:
            The execution result converted via ``Execution.from_e2b_execution``.

        Raises:
            RemoteSandboxExecutionError: on any failure that is not one of
                the retryable errors. Retryable errors propagate unchanged so
                the retry decorator can act on them.
        """
        self.interpreter.set_timeout(_SESSION_TIMEOUT)  # Extend the life of the sandbox
        try:
            _LOGGER.info(
                f"Start code execution in remote sandbox {self.interpreter.sandbox_id}. Timeout: {_SESSION_TIMEOUT}. Code hash: {hash(code)}"
            )
            execution = self.interpreter.notebook.exec_cell(
                code=code,
                on_stdout=lambda msg: _LOGGER.info(msg),
                on_stderr=lambda msg: _LOGGER.info(msg),
            )
            _LOGGER.info(
                f"Finished code execution in remote sandbox {self.interpreter.sandbox_id}. Code hash: {hash(code)}"
            )
            return Execution.from_e2b_execution(execution)
        except E2BCodeInterpreter._RETRYABLE_ERRORS:
            # Bare `raise` keeps the original traceback intact for tenacity.
            raise
        except Exception as e:
            raise RemoteSandboxExecutionError(
                f"Failed executing code in remote sandbox ({self.interpreter.sandbox_id}) due to error '{type(e).__name__} {str(e)}', code: {code}"
            ) from e

    @tenacity.retry(
        wait=tenacity.wait_exponential_jitter(),
        stop=tenacity.stop_after_attempt(3),
        retry=tenacity.retry_if_exception_type(_RETRYABLE_ERRORS),
        before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO),
        after=tenacity.after_log(_LOGGER, logging.INFO),
    )
    def upload_file(self, file: Union[str, Path]) -> Path:
        """Upload a local file into ``self.remote_path`` in the sandbox.

        Parameters:
            file: path of the local file to upload.

        Returns:
            The path of the uploaded file inside the sandbox.
        """
        file_name = Path(file).name
        # The file is reopened on every retry attempt, so each attempt
        # starts reading from the beginning.
        with open(file, "rb") as f:
            self.interpreter.files.write(path=str(self.remote_path / file_name), data=f)
        _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}")
        return self.remote_path / file_name

    def download_file(
        self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path]
    ) -> Path:
        """Download a file from the sandbox to the local filesystem.

        Parameters:
            remote_file_path: path of the file inside the sandbox.
            local_file_path: local destination; overwritten if it exists.

        Returns:
            ``local_file_path`` as a ``Path``.
        """
        with open(local_file_path, "w+b") as f:
            f.write(
                self.interpreter.files.read(path=str(remote_file_path), format="bytes")
            )
        _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}")
        return Path(local_file_path)

    @staticmethod
    def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl:  # type: ignore
        """Instantiate the underlying E2B interpreter.

        The sandbox template comes from the E2B_TEMPLATE_NAME environment
        variable, falling back to "va-sandbox".
        """
        template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox")
        _LOGGER.info(
            f"Creating a new E2BCodeInterpreter using template: {template_name}"
        )
        # Positional arguments first, then keywords; same call semantics as
        # before, without the keyword-before-star-args ordering.
        return E2BCodeInterpreterImpl(*args, template=template_name, **kwargs)
567+
568+
424569
class LocalCodeInterpreter(CodeInterpreter):
425570
def __init__(
426571
self,
@@ -522,6 +667,64 @@ def download_file(
522667
return Path(local_file_path)
523668

524669

670+
class CodeInterpreterFactory:
    """Factory class for creating code interpreters.

    Supported runtimes are "local" (LocalCodeInterpreter) and "e2b"
    (E2BCodeInterpreter). Could be extended to support multiple code
    interpreters.
    """

    # Cache used only by get_default_instance(); new_instance() never reads
    # or writes it.
    _instance_map: Dict[str, CodeInterpreter] = {}
    _default_key = "default"

    @staticmethod
    def get_default_instance() -> CodeInterpreter:
        """Return a shared interpreter, creating it on first use.

        Intended for testing only; emits a warning on every call. Use
        new_instance() in production code.
        """
        warnings.warn(
            "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future.",
            stacklevel=2,  # attribute the warning to the caller, not this factory
        )
        inst_map = CodeInterpreterFactory._instance_map
        instance = inst_map.get(CodeInterpreterFactory._default_key)
        # Compare against None explicitly so a cached interpreter is reused
        # regardless of how its truthiness happens to evaluate.
        if instance is None:
            instance = CodeInterpreterFactory.new_instance()
            inst_map[CodeInterpreterFactory._default_key] = instance
        return instance

    @staticmethod
    def new_instance(
        code_sandbox_runtime: Optional[str] = None,
        remote_path: Optional[Union[str, Path]] = None,
    ) -> CodeInterpreter:
        """Create a fresh code interpreter.

        Parameters:
            code_sandbox_runtime: "local" or "e2b". When None or empty, the
                CODE_SANDBOX_RUNTIME environment variable is consulted; when
                that is unset or empty too, "local" is used.
            remote_path: forwarded to the interpreter's constructor.

        Returns:
            A new E2BCodeInterpreter or LocalCodeInterpreter.

        Raises:
            ValueError: if the resolved runtime is neither "e2b" nor "local".
        """
        if not code_sandbox_runtime:
            # `or "local"` treats an empty env var like an unset one, the
            # same way an empty argument is treated above.
            code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME") or "local"
        if code_sandbox_runtime == "e2b":
            envs = _get_e2b_env()
            instance: CodeInterpreter = E2BCodeInterpreter(
                timeout=_SESSION_TIMEOUT, remote_path=remote_path, envs=envs
            )
        elif code_sandbox_runtime == "local":
            instance = LocalCodeInterpreter(
                timeout=_SESSION_TIMEOUT, remote_path=remote_path
            )
        else:
            raise ValueError(
                f"Unsupported code sandbox runtime: {code_sandbox_runtime}. Supported runtimes: e2b, local"
            )
        return instance
712+
713+
714+
def _get_e2b_env() -> Union[Dict[str, str], None]:
    """Build the environment variables to pass to an E2B sandbox.

    Returns None when neither OPENAI_API_KEY nor ANTHROPIC_API_KEY is set to
    a non-empty value. Otherwise returns a dict holding WORKSPACE (default
    "/home/user") plus whichever of the two API keys are non-empty.
    """
    api_keys = {
        name: os.getenv(name, "")
        for name in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY")
    }
    # Nothing to forward: no environment overrides at all.
    if not any(api_keys.values()):
        return None

    sandbox_env = {"WORKSPACE": os.getenv("WORKSPACE", "/home/user")}
    # Only non-empty keys are forwarded, in OPENAI-then-ANTHROPIC order.
    sandbox_env.update({name: value for name, value in api_keys.items() if value})
    return sandbox_env
726+
727+
525728
def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution:
526729
"""Parse notebook cell outputs to Execution object. Output types:
527730
https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs

0 commit comments

Comments
 (0)