|
6 | 6 | import re
|
7 | 7 | import sys
|
8 | 8 | import traceback
|
| 9 | +import warnings |
9 | 10 | from enum import Enum
|
10 | 11 | from pathlib import Path
|
11 | 12 | from time import sleep
|
12 | 13 | from typing import Any, Dict, Iterable, List, Optional, Union
|
13 | 14 |
|
14 | 15 | import nbformat
|
| 16 | +import tenacity |
15 | 17 | from dotenv import load_dotenv
|
| 18 | +from e2b.exceptions import SandboxException |
| 19 | +from e2b_code_interpreter import CodeInterpreter as E2BCodeInterpreterImpl |
16 | 20 | from e2b_code_interpreter import Execution as E2BExecution
|
17 | 21 | from e2b_code_interpreter import Result as E2BResult
|
| 22 | +from h11._util import LocalProtocolError |
| 23 | +from httpx import ConnectError |
| 24 | +from httpx import RemoteProtocolError as HttpcoreRemoteProtocolError |
| 25 | +from httpx import RemoteProtocolError as HttpxRemoteProtocolError |
18 | 26 | from nbclient import NotebookClient
|
19 | 27 | from nbclient import __version__ as nbclient_version
|
20 | 28 | from nbclient.exceptions import CellTimeoutError, DeadKernelError
|
|
23 | 31 | from pydantic import BaseModel, field_serializer
|
24 | 32 | from typing_extensions import Self
|
25 | 33 |
|
| 34 | +from vision_agent.utils.exceptions import ( |
| 35 | + RemoteSandboxCreationError, |
| 36 | + RemoteSandboxExecutionError, |
| 37 | +) |
| 38 | + |
26 | 39 | load_dotenv()
|
27 | 40 | _LOGGER = logging.getLogger(__name__)
|
28 | 41 | _SESSION_TIMEOUT = 600 # 10 minutes
|
@@ -421,6 +434,138 @@ def download_file(
|
421 | 434 | return Path(local_file_path)
|
422 | 435 |
|
423 | 436 |
|
| 437 | +class E2BCodeInterpreter(CodeInterpreter): |
| 438 | + def __init__( |
| 439 | + self, remote_path: Optional[Union[str, Path]] = None, *args: Any, **kwargs: Any |
| 440 | + ) -> None: |
| 441 | + super().__init__(*args, **kwargs) |
| 442 | + assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set" |
| 443 | + try: |
| 444 | + self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl( |
| 445 | + *args, **kwargs |
| 446 | + ) |
| 447 | + except Exception as e: |
| 448 | + raise RemoteSandboxCreationError( |
| 449 | + f"Failed to create a remote sandbox due to {e}" |
| 450 | + ) from e |
| 451 | + |
| 452 | + result = self.exec_cell( |
| 453 | + """ |
| 454 | +import platform |
| 455 | +import sys |
| 456 | +import importlib.metadata |
| 457 | +
|
| 458 | +print(f"Python version: {sys.version}") |
| 459 | +print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})") |
| 460 | +va_version = importlib.metadata.version("vision-agent") |
| 461 | +print(f"Vision Agent version: {va_version}")""" |
| 462 | + ) |
| 463 | + sys_versions = "\n".join(result.logs.stdout) |
| 464 | + _LOGGER.info( |
| 465 | + f"E2BCodeInterpreter (sandbox id: {self.interpreter.sandbox_id}) initialized:\n{sys_versions}" |
| 466 | + ) |
| 467 | + self.remote_path = Path( |
| 468 | + remote_path if remote_path is not None else "/home/user" |
| 469 | + ) |
| 470 | + |
| 471 | + def close(self, *args: Any, **kwargs: Any) -> None: |
| 472 | + try: |
| 473 | + self.interpreter.kill(request_timeout=2) |
| 474 | + _LOGGER.info( |
| 475 | + f"The sandbox {self.interpreter.sandbox_id} is closed successfully." |
| 476 | + ) |
| 477 | + except Exception as e: |
| 478 | + _LOGGER.warn( |
| 479 | + f"Failed to close the remote sandbox ({self.interpreter.sandbox_id}) due to {e}. This is not an issue. It's likely that the sandbox is already closed due to timeout." |
| 480 | + ) |
| 481 | + |
| 482 | + def restart_kernel(self) -> None: |
| 483 | + self.interpreter.notebook.restart_kernel() |
| 484 | + |
| 485 | + @tenacity.retry( |
| 486 | + wait=tenacity.wait_exponential_jitter(), |
| 487 | + stop=tenacity.stop_after_attempt(3), |
| 488 | + retry=tenacity.retry_if_exception_type( |
| 489 | + ( |
| 490 | + LocalProtocolError, |
| 491 | + HttpxRemoteProtocolError, |
| 492 | + HttpcoreRemoteProtocolError, |
| 493 | + ConnectError, |
| 494 | + SandboxException, |
| 495 | + ) |
| 496 | + ), |
| 497 | + before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), |
| 498 | + after=tenacity.after_log(_LOGGER, logging.INFO), |
| 499 | + ) |
| 500 | + def exec_cell(self, code: str) -> Execution: |
| 501 | + self.interpreter.set_timeout(_SESSION_TIMEOUT) # Extend the life of the sandbox |
| 502 | + try: |
| 503 | + _LOGGER.info( |
| 504 | + f"Start code execution in remote sandbox {self.interpreter.sandbox_id}. Timeout: {_SESSION_TIMEOUT}. Code hash: {hash(code)}" |
| 505 | + ) |
| 506 | + execution = self.interpreter.notebook.exec_cell( |
| 507 | + code=code, |
| 508 | + on_stdout=lambda msg: _LOGGER.info(msg), |
| 509 | + on_stderr=lambda msg: _LOGGER.info(msg), |
| 510 | + ) |
| 511 | + _LOGGER.info( |
| 512 | + f"Finished code execution in remote sandbox {self.interpreter.sandbox_id}. Code hash: {hash(code)}" |
| 513 | + ) |
| 514 | + return Execution.from_e2b_execution(execution) |
| 515 | + except ( |
| 516 | + LocalProtocolError, |
| 517 | + HttpxRemoteProtocolError, |
| 518 | + HttpcoreRemoteProtocolError, |
| 519 | + ConnectError, |
| 520 | + SandboxException, |
| 521 | + ) as e: |
| 522 | + raise e |
| 523 | + except Exception as e: |
| 524 | + raise RemoteSandboxExecutionError( |
| 525 | + f"Failed executing code in remote sandbox ({self.interpreter.sandbox_id}) due to error '{type(e).__name__} {str(e)}', code: {code}" |
| 526 | + ) from e |
| 527 | + |
| 528 | + @tenacity.retry( |
| 529 | + wait=tenacity.wait_exponential_jitter(), |
| 530 | + stop=tenacity.stop_after_attempt(3), |
| 531 | + retry=tenacity.retry_if_exception_type( |
| 532 | + ( |
| 533 | + LocalProtocolError, |
| 534 | + HttpxRemoteProtocolError, |
| 535 | + HttpcoreRemoteProtocolError, |
| 536 | + ConnectError, |
| 537 | + SandboxException, |
| 538 | + ) |
| 539 | + ), |
| 540 | + before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), |
| 541 | + after=tenacity.after_log(_LOGGER, logging.INFO), |
| 542 | + ) |
| 543 | + def upload_file(self, file: Union[str, Path]) -> Path: |
| 544 | + file_name = Path(file).name |
| 545 | + with open(file, "rb") as f: |
| 546 | + self.interpreter.files.write(path=str(self.remote_path / file_name), data=f) |
| 547 | + _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}") |
| 548 | + return self.remote_path / file_name |
| 549 | + |
| 550 | + def download_file( |
| 551 | + self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path] |
| 552 | + ) -> Path: |
| 553 | + with open(local_file_path, "w+b") as f: |
| 554 | + f.write( |
| 555 | + self.interpreter.files.read(path=str(remote_file_path), format="bytes") |
| 556 | + ) |
| 557 | + _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}") |
| 558 | + return Path(local_file_path) |
| 559 | + |
| 560 | + @staticmethod |
| 561 | + def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl: # type: ignore |
| 562 | + template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox") |
| 563 | + _LOGGER.info( |
| 564 | + f"Creating a new E2BCodeInterpreter using template: {template_name}" |
| 565 | + ) |
| 566 | + return E2BCodeInterpreterImpl(template=template_name, *args, **kwargs) |
| 567 | + |
| 568 | + |
424 | 569 | class LocalCodeInterpreter(CodeInterpreter):
|
425 | 570 | def __init__(
|
426 | 571 | self,
|
@@ -522,6 +667,64 @@ def download_file(
|
522 | 667 | return Path(local_file_path)
|
523 | 668 |
|
524 | 669 |
|
| 670 | +class CodeInterpreterFactory: |
| 671 | + """Factory class for creating code interpreters. |
| 672 | + Could be extended to support multiple code interpreters. |
| 673 | + """ |
| 674 | + |
| 675 | + _instance_map: Dict[str, CodeInterpreter] = {} |
| 676 | + _default_key = "default" |
| 677 | + |
| 678 | + @staticmethod |
| 679 | + def get_default_instance() -> CodeInterpreter: |
| 680 | + warnings.warn( |
| 681 | + "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future." |
| 682 | + ) |
| 683 | + inst_map = CodeInterpreterFactory._instance_map |
| 684 | + instance = inst_map.get(CodeInterpreterFactory._default_key) |
| 685 | + if instance: |
| 686 | + return instance |
| 687 | + instance = CodeInterpreterFactory.new_instance() |
| 688 | + inst_map[CodeInterpreterFactory._default_key] = instance |
| 689 | + return instance |
| 690 | + |
| 691 | + @staticmethod |
| 692 | + def new_instance( |
| 693 | + code_sandbox_runtime: Optional[str] = None, |
| 694 | + remote_path: Optional[Union[str, Path]] = None, |
| 695 | + ) -> CodeInterpreter: |
| 696 | + if not code_sandbox_runtime: |
| 697 | + code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local") |
| 698 | + if code_sandbox_runtime == "e2b": |
| 699 | + envs = _get_e2b_env() |
| 700 | + instance: CodeInterpreter = E2BCodeInterpreter( |
| 701 | + timeout=_SESSION_TIMEOUT, remote_path=remote_path, envs=envs |
| 702 | + ) |
| 703 | + elif code_sandbox_runtime == "local": |
| 704 | + instance = LocalCodeInterpreter( |
| 705 | + timeout=_SESSION_TIMEOUT, remote_path=remote_path |
| 706 | + ) |
| 707 | + else: |
| 708 | + raise ValueError( |
| 709 | + f"Unsupported code sandbox runtime: {code_sandbox_runtime}. Supported runtimes: e2b, local" |
| 710 | + ) |
| 711 | + return instance |
| 712 | + |
| 713 | + |
| 714 | +def _get_e2b_env() -> Union[Dict[str, str], None]: |
| 715 | + openai_api_key = os.getenv("OPENAI_API_KEY", "") |
| 716 | + anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") |
| 717 | + if openai_api_key or anthropic_api_key: |
| 718 | + envs = {"WORKSPACE": os.getenv("WORKSPACE", "/home/user")} |
| 719 | + if openai_api_key: |
| 720 | + envs["OPENAI_API_KEY"] = openai_api_key |
| 721 | + if anthropic_api_key: |
| 722 | + envs["ANTHROPIC_API_KEY"] = anthropic_api_key |
| 723 | + else: |
| 724 | + envs = None |
| 725 | + return envs |
| 726 | + |
| 727 | + |
525 | 728 | def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution:
|
526 | 729 | """Parse notebook cell outputs to Execution object. Output types:
|
527 | 730 | https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs
|
|
0 commit comments