From c66a4639b39951cf7af1d8b8e6d6cf05754ca0d2 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 10 Oct 2024 16:41:51 +0800 Subject: [PATCH 01/12] calling --- vision_agent/agent/vision_agent.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index ba6e1d64..ef49086a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -21,7 +21,7 @@ use_extra_vision_agent_args, ) from vision_agent.utils import CodeInterpreterFactory -from vision_agent.utils.execute import CodeInterpreter, Execution +from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution logging.basicConfig(level=logging.INFO) _LOGGER = logging.getLogger(__name__) @@ -29,6 +29,7 @@ WORKSPACE.mkdir(parents=True, exist_ok=True) if str(WORKSPACE) != "": os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}" +_SESSION_TIMEOUT = 600 # 10 minutes class BoilerplateCode: @@ -195,7 +196,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[CodeInterpreter] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent. @@ -206,13 +207,17 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python """ self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent self.max_iterations = 12 self.verbosity = verbosity - self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = ( + code_interpreter + if code_interpreter is not None + else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT) + ) self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) @@ -284,9 +289,7 @@ def chat_with_code( # this is setting remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") - with CodeInterpreterFactory.new_instance( - code_sandbox_runtime=self.code_sandbox_runtime, - ) as code_interpreter: + with self.code_interpreter as code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) last_user_message = chat[-1] @@ -472,7 +475,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[CodeInterpreter] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent using OpenAI LMMs. @@ -483,7 +486,7 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python """ agent = OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent @@ -491,7 +494,7 @@ def __init__( agent, verbosity, local_artifacts_path, - code_sandbox_runtime, + code_interpreter, callback_message, ) @@ -502,7 +505,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[CodeInterpreter] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent using Anthropic LMMs. @@ -513,7 +516,7 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python """ agent = AnthropicLMM(temperature=0.0) if agent is None else agent @@ -521,6 +524,6 @@ def __init__( agent, verbosity, local_artifacts_path, - code_sandbox_runtime, + code_interpreter, callback_message, ) From c67e2250d7da93ab4dda6409c2aae65d4999652a Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 10 Oct 2024 21:41:08 +0800 Subject: [PATCH 02/12] remove e2bcodeinterpreter and codefactory --- vision_agent/agent/vision_agent.py | 1 - vision_agent/agent/vision_agent_coder.py | 26 +-- vision_agent/utils/__init__.py | 1 - vision_agent/utils/execute.py | 203 ----------------------- 4 files changed, 14 insertions(+), 217 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index ef49086a..98e70fa8 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -20,7 +20,6 @@ check_and_load_image, use_extra_vision_agent_args, ) -from vision_agent.utils import CodeInterpreterFactory from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution logging.basicConfig(level=logging.INFO) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 1e5030a2..e8dc2247 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -38,8 +38,8 @@ OpenAILMM, ) from vision_agent.tools.meta_tools import get_diff -from vision_agent.utils import CodeInterpreterFactory, Execution -from vision_agent.utils.execute import CodeInterpreter +from vision_agent.utils import Execution +from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter from vision_agent.utils.image_utils import b64_to_pil from vision_agent.utils.sim import AzureSim, OllamaSim, Sim from vision_agent.utils.video import play_video @@ -49,6 +49,7 @@ _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 _CONSOLE = Console() +_SESSION_TIMEOUT = 600 # 10 minutes class DefaultImports: @@ -623,7 +624,7 @@ def __init__( tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - code_sandbox_runtime: Optional[str] = None, + code_interpreter: Optional[CodeInterpreter] = None, ) -> None: """Initialize the Vision Agent Coder. @@ -641,11 +642,10 @@ def __init__( in a web application where multiple VisionAgentCoder instances are running in parallel. This callback ensures that the progress are not mixed up. - code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A - code sandbox is used to run the generated code. It can be one of the - following values: None, "local" or "e2b". If None, VisionAgentCoder - will read the value from the environment variable CODE_SANDBOX_RUNTIME. - If it's also None, the local python runtime environment will be used. + code_interpreter (Optional[CodeInterpreter]): the code interpreter to use. A + code interpreter is used to run the generated code. It can be one of the + following values: None, LocalCodeInterpreter or E2BCodeInterpreter. + If None, LocalCodeInterpreter, which is the local python runtime environment will be used. """ self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner @@ -662,7 +662,11 @@ def __init__( else tool_recommender ) self.report_progress_callback = report_progress_callback - self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = ( + code_interpreter + if code_interpreter is not None + else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT) + ) def __call__( self, @@ -723,9 +727,7 @@ def chat_with_workflow( raise ValueError("Chat cannot be empty.") # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues - with CodeInterpreterFactory.new_instance( - code_sandbox_runtime=self.code_sandbox_runtime - ) as code_interpreter: + with self.code_interpreter as code_interpreter: chat = copy.deepcopy(chat) media_list = [] for chat_i in chat: diff --git a/vision_agent/utils/__init__.py b/vision_agent/utils/__init__.py index 2810713a..819477c1 100644 --- a/vision_agent/utils/__init__.py +++ b/vision_agent/utils/__init__.py @@ -1,6 +1,5 @@ from .execute import ( CodeInterpreter, - CodeInterpreterFactory, Error, Execution, Logs, diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 1907d106..38c38150 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -6,23 +6,15 @@ import re import sys import traceback -import warnings from enum import Enum from pathlib import Path from time import sleep from typing import Any, Dict, Iterable, List, Optional, Union import nbformat -import tenacity from dotenv import load_dotenv -from e2b.exceptions import SandboxException -from e2b_code_interpreter import CodeInterpreter as E2BCodeInterpreterImpl from e2b_code_interpreter import Execution as E2BExecution from e2b_code_interpreter import Result as E2BResult -from h11._util import LocalProtocolError -from httpx import ConnectError -from httpx import RemoteProtocolError as HttpcoreRemoteProtocolError -from httpx import RemoteProtocolError as HttpxRemoteProtocolError from nbclient import NotebookClient from nbclient import __version__ as nbclient_version from nbclient.exceptions import CellTimeoutError, DeadKernelError @@ -31,11 +23,6 @@ from pydantic import BaseModel, field_serializer from typing_extensions import Self -from vision_agent.utils.exceptions import ( - RemoteSandboxCreationError, - RemoteSandboxExecutionError, -) - load_dotenv() _LOGGER = logging.getLogger(__name__) _SESSION_TIMEOUT = 600 # 10 minutes @@ -434,138 +421,6 @@ def download_file( return Path(local_file_path) -class E2BCodeInterpreter(CodeInterpreter): - def __init__( - self, remote_path: Optional[Union[str, Path]] = None, *args: Any, **kwargs: Any - ) -> None: - super().__init__(*args, **kwargs) - assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set" - try: - self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl( - *args, **kwargs - ) - except Exception as e: - raise RemoteSandboxCreationError( - f"Failed to create a remote sandbox due to {e}" - ) from e - - result = self.exec_cell( - """ -import platform -import sys -import importlib.metadata - -print(f"Python version: {sys.version}") -print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})") -va_version = importlib.metadata.version("vision-agent") -print(f"Vision Agent version: {va_version}")""" - ) - sys_versions = "\n".join(result.logs.stdout) - _LOGGER.info( - f"E2BCodeInterpreter (sandbox id: {self.interpreter.sandbox_id}) initialized:\n{sys_versions}" - ) - self.remote_path = Path( - remote_path if remote_path is not None else "/home/user" - ) - - def close(self, *args: Any, **kwargs: Any) -> None: - try: - self.interpreter.kill(request_timeout=2) - _LOGGER.info( - f"The sandbox {self.interpreter.sandbox_id} is closed successfully." - ) - except Exception as e: - _LOGGER.warn( - f"Failed to close the remote sandbox ({self.interpreter.sandbox_id}) due to {e}. This is not an issue. It's likely that the sandbox is already closed due to timeout." - ) - - def restart_kernel(self) -> None: - self.interpreter.notebook.restart_kernel() - - @tenacity.retry( - wait=tenacity.wait_exponential_jitter(), - stop=tenacity.stop_after_attempt(3), - retry=tenacity.retry_if_exception_type( - ( - LocalProtocolError, - HttpxRemoteProtocolError, - HttpcoreRemoteProtocolError, - ConnectError, - SandboxException, - ) - ), - before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), - after=tenacity.after_log(_LOGGER, logging.INFO), - ) - def exec_cell(self, code: str) -> Execution: - self.interpreter.set_timeout(_SESSION_TIMEOUT) # Extend the life of the sandbox - try: - _LOGGER.info( - f"Start code execution in remote sandbox {self.interpreter.sandbox_id}. Timeout: {_SESSION_TIMEOUT}. Code hash: {hash(code)}" - ) - execution = self.interpreter.notebook.exec_cell( - code=code, - on_stdout=lambda msg: _LOGGER.info(msg), - on_stderr=lambda msg: _LOGGER.info(msg), - ) - _LOGGER.info( - f"Finished code execution in remote sandbox {self.interpreter.sandbox_id}. Code hash: {hash(code)}" - ) - return Execution.from_e2b_execution(execution) - except ( - LocalProtocolError, - HttpxRemoteProtocolError, - HttpcoreRemoteProtocolError, - ConnectError, - SandboxException, - ) as e: - raise e - except Exception as e: - raise RemoteSandboxExecutionError( - f"Failed executing code in remote sandbox ({self.interpreter.sandbox_id}) due to error '{type(e).__name__} {str(e)}', code: {code}" - ) from e - - @tenacity.retry( - wait=tenacity.wait_exponential_jitter(), - stop=tenacity.stop_after_attempt(3), - retry=tenacity.retry_if_exception_type( - ( - LocalProtocolError, - HttpxRemoteProtocolError, - HttpcoreRemoteProtocolError, - ConnectError, - SandboxException, - ) - ), - before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), - after=tenacity.after_log(_LOGGER, logging.INFO), - ) - def upload_file(self, file: Union[str, Path]) -> Path: - file_name = Path(file).name - with open(file, "rb") as f: - self.interpreter.files.write(path=str(self.remote_path / file_name), data=f) - _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}") - return self.remote_path / file_name - - def download_file( - self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path] - ) -> Path: - with open(local_file_path, "w+b") as f: - f.write( - self.interpreter.files.read(path=str(remote_file_path), format="bytes") - ) - _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}") - return Path(local_file_path) - - @staticmethod - def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl: # type: ignore - template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox") - _LOGGER.info( - f"Creating a new E2BCodeInterpreter using template: {template_name}" - ) - return E2BCodeInterpreterImpl(template=template_name, *args, **kwargs) - - class LocalCodeInterpreter(CodeInterpreter): def __init__( self, @@ -667,64 +522,6 @@ def download_file( return Path(local_file_path) -class CodeInterpreterFactory: - """Factory class for creating code interpreters. - Could be extended to support multiple code interpreters. - """ - - _instance_map: Dict[str, CodeInterpreter] = {} - _default_key = "default" - - @staticmethod - def get_default_instance() -> CodeInterpreter: - warnings.warn( - "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future." - ) - inst_map = CodeInterpreterFactory._instance_map - instance = inst_map.get(CodeInterpreterFactory._default_key) - if instance: - return instance - instance = CodeInterpreterFactory.new_instance() - inst_map[CodeInterpreterFactory._default_key] = instance - return instance - - @staticmethod - def new_instance( - code_sandbox_runtime: Optional[str] = None, - remote_path: Optional[Union[str, Path]] = None, - ) -> CodeInterpreter: - if not code_sandbox_runtime: - code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local") - if code_sandbox_runtime == "e2b": - envs = _get_e2b_env() - instance: CodeInterpreter = E2BCodeInterpreter( - timeout=_SESSION_TIMEOUT, remote_path=remote_path, envs=envs - ) - elif code_sandbox_runtime == "local": - instance = LocalCodeInterpreter( - timeout=_SESSION_TIMEOUT, remote_path=remote_path - ) - else: - raise ValueError( - f"Unsupported code sandbox runtime: {code_sandbox_runtime}. Supported runtimes: e2b, local" - ) - return instance - - -def _get_e2b_env() -> Union[Dict[str, str], None]: - openai_api_key = os.getenv("OPENAI_API_KEY", "") - anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") - if openai_api_key or anthropic_api_key: - envs = {"WORKSPACE": os.getenv("WORKSPACE", "/home/user")} - if openai_api_key: - envs["OPENAI_API_KEY"] = openai_api_key - if anthropic_api_key: - envs["ANTHROPIC_API_KEY"] = anthropic_api_key - else: - envs = None - return envs - - def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution: """Parse notebook cell outputs to Execution object. Output types: https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs From b0104ffd3fddfb2c705b5aded7227fa102e9976a Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 10 Oct 2024 21:53:25 +0800 Subject: [PATCH 03/12] remove e2b --- vision_agent/utils/execute.py | 39 ----------------------------------- 1 file changed, 39 deletions(-) diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 38c38150..17b9d3f0 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -13,8 +13,6 @@ import nbformat from dotenv import load_dotenv -from e2b_code_interpreter import Execution as E2BExecution -from e2b_code_interpreter import Result as E2BResult from nbclient import NotebookClient from nbclient import __version__ as nbclient_version from nbclient.exceptions import CellTimeoutError, DeadKernelError @@ -200,23 +198,6 @@ def formats(self) -> Iterable[str]: formats.extend(iter(self.extra)) return formats - @staticmethod - def from_e2b_result(result: E2BResult) -> "Result": - """ - Creates a Result object from an E2BResult object. - """ - data = { - MimeType.TEXT_PLAIN.value: result.text, - MimeType.IMAGE_PNG.value: result.png, - MimeType.APPLICATION_JSON.value: result.json, - } - for k, v in result.extra.items(): - data[k] = v - return Result( - is_main_result=result.is_main_result, - data=data, - ) - class Logs(BaseModel): """Data printed to stdout and stderr during execution, usually by print statements, @@ -357,26 +338,6 @@ def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution": ) ) - @staticmethod - def from_e2b_execution(exec: E2BExecution) -> "Execution": - """Creates an Execution object from an E2BResult object.""" - return Execution( - results=[Result.from_e2b_result(res) for res in exec.results], - logs=Logs(stdout=exec.logs.stdout, stderr=exec.logs.stderr), - error=( - Error( - name=exec.error.name, - value=_remove_escape_and_color_codes(exec.error.value), - traceback_raw=[ - _remove_escape_and_color_codes(line) - for line in exec.error.traceback.split("\n") - ], - ) - if exec.error - else None - ), - ) - class CodeInterpreter(abc.ABC): """Code interpreter interface.""" From b93d62fa046317e302522bc1ad97ecdac4a72948 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 10 Oct 2024 21:55:10 +0800 Subject: [PATCH 04/12] remove e2b from pyproject --- pyproject.toml | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ca6726a7..90e447a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,14 +8,14 @@ version = "0.2.160" description = "Toolset for Vision Agent" authors = ["Landing AI "] readme = "README.md" -packages = [{include = "vision_agent"}] +packages = [{ include = "vision_agent" }] [tool.poetry.urls] "Homepage" = "https://landing.ai" "repository" = "https://github.com/landing-ai/vision-agent" "documentation" = "https://github.com/landing-ai/vision-agent" -[tool.poetry.dependencies] # main dependency group +[tool.poetry.dependencies] # main dependency group python = ">=3.9,<4.0" numpy = ">=1.21.0,<2.0.0" @@ -35,8 +35,6 @@ nbformat = "^5.10.4" rich = "^13.7.1" langsmith = "^0.1.58" ipykernel = "^6.29.4" -e2b = "^0.17.2a50" -e2b-code-interpreter = "0.0.11a37" tenacity = "^8.3.0" pillow-heif = "^0.16.0" pytube = "15.0.0" @@ -58,7 +56,7 @@ types-tqdm = "^4.65.0.1" setuptools = "^68.0.0" griffe = "^0.45.3" mkdocs = "^1.5.3" -mkdocstrings = {extras = ["python"], version = "^0.23.0"} +mkdocstrings = { extras = ["python"], version = "^0.23.0" } mkdocs-material = "^9.4.2" types-tabulate = "^0.9.0.20240106" scikit-image = "<0.23.1" @@ -72,7 +70,7 @@ log_cli_date_format = "%Y-%m-%d %H:%M:%S" [tool.black] exclude = '.vscode|.eggs|venv' -line-length = 88 # suggested by black official site +line-length = 88 # suggested by black official site [tool.isort] line_length = 88 @@ -98,10 +96,4 @@ show_error_codes = true [[tool.mypy.overrides]] ignore_missing_imports = true -module = [ - "cv2.*", - "openai.*", - "sentence_transformers.*", - "e2b_code_interpreter.*", - "e2b.*" -] +module = ["cv2.*", "openai.*", "sentence_transformers.*"] From 1e786ff876c63784dc4bb6d1d41d19b446336f77 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 10 Oct 2024 22:48:42 +0800 Subject: [PATCH 05/12] remove tenacity --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 90e447a1..7271ae83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ nbformat = "^5.10.4" rich = "^13.7.1" langsmith = "^0.1.58" ipykernel = "^6.29.4" -tenacity = "^8.3.0" pillow-heif = "^0.16.0" pytube = "15.0.0" anthropic = "^0.31.0" From 7f4fc3b4a17c5228c4def91cb4e259909a2b1468 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Thu, 10 Oct 2024 22:59:42 +0800 Subject: [PATCH 06/12] add VERBOSITY env var --- vision_agent/tools/meta_tools.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 024d4230..c9fc7be0 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -28,6 +28,7 @@ CURRENT_LINE = 0 DEFAULT_WINDOW_SIZE = 100 ZMQ_PORT = os.environ.get("ZMQ_PORT", None) +VERBOSITY = os.environ.get("VERBOSITY", 0) def report_progress_callback(port: int, inp: Dict[str, Any]) -> None: @@ -375,7 +376,7 @@ def detect_dogs(image_path: str): ) ) else: - agent = va.agent.VisionAgentCoder() + agent = va.agent.VisionAgentCoder(verbosity=int(VERBOSITY)) fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}] response = agent.chat_with_workflow( @@ -438,7 +439,7 @@ def detect_dogs(image_path: str): return dogs """ - agent = va.agent.VisionAgentCoder() + agent = va.agent.VisionAgentCoder(verbosity=int(VERBOSITY)) if name not in artifacts: print(f"[Artifact {name} does not exist]") return f"[Artifact {name} does not exist]" From 3844264d456a97e6219813f3bd9d090d336c9956 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 08:38:53 +0800 Subject: [PATCH 07/12] Revert "remove tenacity" This reverts commit 1e786ff876c63784dc4bb6d1d41d19b446336f77. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7271ae83..90e447a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ nbformat = "^5.10.4" rich = "^13.7.1" langsmith = "^0.1.58" ipykernel = "^6.29.4" +tenacity = "^8.3.0" pillow-heif = "^0.16.0" pytube = "15.0.0" anthropic = "^0.31.0" From a422041ce50cb9f1f95390184306acb4fb51ac58 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 08:39:02 +0800 Subject: [PATCH 08/12] Revert "remove e2b from pyproject" This reverts commit b93d62fa046317e302522bc1ad97ecdac4a72948. --- pyproject.toml | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 90e447a1..ca6726a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,14 +8,14 @@ version = "0.2.160" description = "Toolset for Vision Agent" authors = ["Landing AI "] readme = "README.md" -packages = [{ include = "vision_agent" }] +packages = [{include = "vision_agent"}] [tool.poetry.urls] "Homepage" = "https://landing.ai" "repository" = "https://github.com/landing-ai/vision-agent" "documentation" = "https://github.com/landing-ai/vision-agent" -[tool.poetry.dependencies] # main dependency group +[tool.poetry.dependencies] # main dependency group python = ">=3.9,<4.0" numpy = ">=1.21.0,<2.0.0" @@ -35,6 +35,8 @@ nbformat = "^5.10.4" rich = "^13.7.1" langsmith = "^0.1.58" ipykernel = "^6.29.4" +e2b = "^0.17.2a50" +e2b-code-interpreter = "0.0.11a37" tenacity = "^8.3.0" pillow-heif = "^0.16.0" pytube = "15.0.0" @@ -56,7 +58,7 @@ types-tqdm = "^4.65.0.1" setuptools = "^68.0.0" griffe = "^0.45.3" mkdocs = "^1.5.3" -mkdocstrings = { extras = ["python"], version = "^0.23.0" } +mkdocstrings = {extras = ["python"], version = "^0.23.0"} mkdocs-material = "^9.4.2" types-tabulate = "^0.9.0.20240106" scikit-image = "<0.23.1" @@ -70,7 +72,7 @@ log_cli_date_format = "%Y-%m-%d %H:%M:%S" [tool.black] exclude = '.vscode|.eggs|venv' -line-length = 88 # suggested by black official site +line-length = 88 # suggested by black official site [tool.isort] line_length = 88 @@ -96,4 +98,10 @@ show_error_codes = true [[tool.mypy.overrides]] ignore_missing_imports = true -module = ["cv2.*", "openai.*", "sentence_transformers.*"] +module = [ + "cv2.*", + "openai.*", + "sentence_transformers.*", + "e2b_code_interpreter.*", + "e2b.*" +] From c6ef4adb016a7380a62e108ae6c916c7c3294072 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 08:39:11 +0800 Subject: [PATCH 09/12] Revert "remove e2b" This reverts commit b0104ffd3fddfb2c705b5aded7227fa102e9976a. --- vision_agent/utils/execute.py | 39 +++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 17b9d3f0..38c38150 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -13,6 +13,8 @@ import nbformat from dotenv import load_dotenv +from e2b_code_interpreter import Execution as E2BExecution +from e2b_code_interpreter import Result as E2BResult from nbclient import NotebookClient from nbclient import __version__ as nbclient_version from nbclient.exceptions import CellTimeoutError, DeadKernelError @@ -198,6 +200,23 @@ def formats(self) -> Iterable[str]: formats.extend(iter(self.extra)) return formats + @staticmethod + def from_e2b_result(result: E2BResult) -> "Result": + """ + Creates a Result object from an E2BResult object. + """ + data = { + MimeType.TEXT_PLAIN.value: result.text, + MimeType.IMAGE_PNG.value: result.png, + MimeType.APPLICATION_JSON.value: result.json, + } + for k, v in result.extra.items(): + data[k] = v + return Result( + is_main_result=result.is_main_result, + data=data, + ) + class Logs(BaseModel): """Data printed to stdout and stderr during execution, usually by print statements, @@ -338,6 +357,26 @@ def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution": ) ) + @staticmethod + def from_e2b_execution(exec: E2BExecution) -> "Execution": + """Creates an Execution object from an E2BResult object.""" + return Execution( + results=[Result.from_e2b_result(res) for res in exec.results], + logs=Logs(stdout=exec.logs.stdout, stderr=exec.logs.stderr), + error=( + Error( + name=exec.error.name, + value=_remove_escape_and_color_codes(exec.error.value), + traceback_raw=[ + _remove_escape_and_color_codes(line) + for line in exec.error.traceback.split("\n") + ], + ) + if exec.error + else None + ), + ) + class CodeInterpreter(abc.ABC): """Code interpreter interface.""" From 3ee619837438aaa2b01d6b13fbe140eab77422aa Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 08:39:36 +0800 Subject: [PATCH 10/12] Revert "remove e2bcodeinterpreter and codefactory" This reverts commit c67e2250d7da93ab4dda6409c2aae65d4999652a. --- vision_agent/agent/vision_agent.py | 1 + vision_agent/agent/vision_agent_coder.py | 26 ++- vision_agent/utils/__init__.py | 1 + vision_agent/utils/execute.py | 203 +++++++++++++++++++++++ 4 files changed, 217 insertions(+), 14 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 98e70fa8..ef49086a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -20,6 +20,7 @@ check_and_load_image, use_extra_vision_agent_args, ) +from vision_agent.utils import CodeInterpreterFactory from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution logging.basicConfig(level=logging.INFO) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index e8dc2247..1e5030a2 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -38,8 +38,8 @@ OpenAILMM, ) from vision_agent.tools.meta_tools import get_diff -from vision_agent.utils import Execution -from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter +from vision_agent.utils import CodeInterpreterFactory, Execution +from vision_agent.utils.execute import CodeInterpreter from vision_agent.utils.image_utils import b64_to_pil from vision_agent.utils.sim import AzureSim, OllamaSim, Sim from vision_agent.utils.video import play_video @@ -49,7 +49,6 @@ _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 _CONSOLE = Console() -_SESSION_TIMEOUT = 600 # 10 minutes class DefaultImports: @@ -624,7 +623,7 @@ def __init__( tool_recommender: Optional[Sim] = None, verbosity: int = 0, report_progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, ) -> None: """Initialize the Vision Agent Coder. @@ -642,10 +641,11 @@ def __init__( in a web application where multiple VisionAgentCoder instances are running in parallel. This callback ensures that the progress are not mixed up. - code_interpreter (Optional[CodeInterpreter]): the code interpreter to use. A - code interpreter is used to run the generated code. It can be one of the - following values: None, LocalCodeInterpreter or E2BCodeInterpreter. - If None, LocalCodeInterpreter, which is the local python runtime environment will be used. + code_sandbox_runtime (Optional[str]): the code sandbox runtime to use. A + code sandbox is used to run the generated code. It can be one of the + following values: None, "local" or "e2b". If None, VisionAgentCoder + will read the value from the environment variable CODE_SANDBOX_RUNTIME. + If it's also None, the local python runtime environment will be used. """ self.planner = AnthropicLMM(temperature=0.0) if planner is None else planner @@ -662,11 +662,7 @@ def __init__( else tool_recommender ) self.report_progress_callback = report_progress_callback - self.code_interpreter = ( - code_interpreter - if code_interpreter is not None - else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT) - ) + self.code_sandbox_runtime = code_sandbox_runtime def __call__( self, @@ -727,7 +723,9 @@ def chat_with_workflow( raise ValueError("Chat cannot be empty.") # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues - with self.code_interpreter as code_interpreter: + with CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_sandbox_runtime + ) as code_interpreter: chat = copy.deepcopy(chat) media_list = [] for chat_i in chat: diff --git a/vision_agent/utils/__init__.py b/vision_agent/utils/__init__.py index 819477c1..2810713a 100644 --- a/vision_agent/utils/__init__.py +++ b/vision_agent/utils/__init__.py @@ -1,5 +1,6 @@ from .execute import ( CodeInterpreter, + CodeInterpreterFactory, Error, Execution, Logs, diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 38c38150..1907d106 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -6,15 +6,23 @@ import re import sys import traceback +import warnings from enum import Enum from pathlib import Path from time import sleep from typing import Any, Dict, Iterable, List, Optional, Union import nbformat +import tenacity from dotenv import load_dotenv +from e2b.exceptions import SandboxException +from e2b_code_interpreter import CodeInterpreter as E2BCodeInterpreterImpl from e2b_code_interpreter import Execution as E2BExecution from e2b_code_interpreter import Result as E2BResult +from h11._util import LocalProtocolError +from httpx import ConnectError +from httpx import RemoteProtocolError as HttpcoreRemoteProtocolError +from httpx import RemoteProtocolError as HttpxRemoteProtocolError from nbclient import NotebookClient from nbclient import __version__ as nbclient_version from nbclient.exceptions import CellTimeoutError, DeadKernelError @@ -23,6 +31,11 @@ from pydantic import BaseModel, field_serializer from typing_extensions import Self +from vision_agent.utils.exceptions import ( + RemoteSandboxCreationError, + RemoteSandboxExecutionError, +) + load_dotenv() _LOGGER = logging.getLogger(__name__) _SESSION_TIMEOUT = 600 # 10 minutes @@ -421,6 +434,138 @@ def download_file( return Path(local_file_path) +class E2BCodeInterpreter(CodeInterpreter): + def __init__( + self, remote_path: Optional[Union[str, Path]] = None, *args: Any, **kwargs: Any + ) -> None: + super().__init__(*args, **kwargs) + assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set" + try: + self.interpreter = E2BCodeInterpreter._new_e2b_interpreter_impl( + *args, **kwargs + ) + except Exception as e: + raise RemoteSandboxCreationError( + f"Failed to create a remote sandbox due to {e}" + ) from e + + result = self.exec_cell( + """ +import platform +import sys +import importlib.metadata + +print(f"Python version: {sys.version}") +print(f"OS version: {platform.system()} {platform.release()} ({platform.architecture()})") +va_version = importlib.metadata.version("vision-agent") +print(f"Vision Agent version: {va_version}")""" + ) + sys_versions = "\n".join(result.logs.stdout) + _LOGGER.info( + f"E2BCodeInterpreter (sandbox id: {self.interpreter.sandbox_id}) initialized:\n{sys_versions}" + ) + self.remote_path = Path( + remote_path if remote_path is not None else "/home/user" + ) + + def close(self, *args: Any, **kwargs: Any) -> None: + try: + self.interpreter.kill(request_timeout=2) + _LOGGER.info( + f"The sandbox {self.interpreter.sandbox_id} is closed successfully." + ) + except Exception as e: + _LOGGER.warn( + f"Failed to close the remote sandbox ({self.interpreter.sandbox_id}) due to {e}. This is not an issue. It's likely that the sandbox is already closed due to timeout." + ) + + def restart_kernel(self) -> None: + self.interpreter.notebook.restart_kernel() + + @tenacity.retry( + wait=tenacity.wait_exponential_jitter(), + stop=tenacity.stop_after_attempt(3), + retry=tenacity.retry_if_exception_type( + ( + LocalProtocolError, + HttpxRemoteProtocolError, + HttpcoreRemoteProtocolError, + ConnectError, + SandboxException, + ) + ), + before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), + after=tenacity.after_log(_LOGGER, logging.INFO), + ) + def exec_cell(self, code: str) -> Execution: + self.interpreter.set_timeout(_SESSION_TIMEOUT) # Extend the life of the sandbox + try: + _LOGGER.info( + f"Start code execution in remote sandbox {self.interpreter.sandbox_id}. Timeout: {_SESSION_TIMEOUT}. Code hash: {hash(code)}" + ) + execution = self.interpreter.notebook.exec_cell( + code=code, + on_stdout=lambda msg: _LOGGER.info(msg), + on_stderr=lambda msg: _LOGGER.info(msg), + ) + _LOGGER.info( + f"Finished code execution in remote sandbox {self.interpreter.sandbox_id}. Code hash: {hash(code)}" + ) + return Execution.from_e2b_execution(execution) + except ( + LocalProtocolError, + HttpxRemoteProtocolError, + HttpcoreRemoteProtocolError, + ConnectError, + SandboxException, + ) as e: + raise e + except Exception as e: + raise RemoteSandboxExecutionError( + f"Failed executing code in remote sandbox ({self.interpreter.sandbox_id}) due to error '{type(e).__name__} {str(e)}', code: {code}" + ) from e + + @tenacity.retry( + wait=tenacity.wait_exponential_jitter(), + stop=tenacity.stop_after_attempt(3), + retry=tenacity.retry_if_exception_type( + ( + LocalProtocolError, + HttpxRemoteProtocolError, + HttpcoreRemoteProtocolError, + ConnectError, + SandboxException, + ) + ), + before_sleep=tenacity.before_sleep_log(_LOGGER, logging.INFO), + after=tenacity.after_log(_LOGGER, logging.INFO), + ) + def upload_file(self, file: Union[str, Path]) -> Path: + file_name = Path(file).name + with open(file, "rb") as f: + self.interpreter.files.write(path=str(self.remote_path / file_name), data=f) + _LOGGER.info(f"File ({file}) is uploaded to: {str(self.remote_path)}") + return self.remote_path / file_name + + def download_file( + self, remote_file_path: Union[str, Path], local_file_path: Union[str, Path] + ) -> Path: + with open(local_file_path, "w+b") as f: + f.write( + self.interpreter.files.read(path=str(remote_file_path), format="bytes") + ) + _LOGGER.info(f"File ({remote_file_path}) is downloaded to: {local_file_path}") + return Path(local_file_path) + + @staticmethod + def _new_e2b_interpreter_impl(*args, **kwargs) -> E2BCodeInterpreterImpl: # type: ignore + template_name = os.environ.get("E2B_TEMPLATE_NAME", "va-sandbox") + _LOGGER.info( + f"Creating a new E2BCodeInterpreter using template: {template_name}" + ) + return E2BCodeInterpreterImpl(template=template_name, *args, **kwargs) + + class LocalCodeInterpreter(CodeInterpreter): def __init__( self, @@ -522,6 +667,64 @@ def download_file( return Path(local_file_path) +class CodeInterpreterFactory: + """Factory class for creating code interpreters. + Could be extended to support multiple code interpreters. + """ + + _instance_map: Dict[str, CodeInterpreter] = {} + _default_key = "default" + + @staticmethod + def get_default_instance() -> CodeInterpreter: + warnings.warn( + "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future." + ) + inst_map = CodeInterpreterFactory._instance_map + instance = inst_map.get(CodeInterpreterFactory._default_key) + if instance: + return instance + instance = CodeInterpreterFactory.new_instance() + inst_map[CodeInterpreterFactory._default_key] = instance + return instance + + @staticmethod + def new_instance( + code_sandbox_runtime: Optional[str] = None, + remote_path: Optional[Union[str, Path]] = None, + ) -> CodeInterpreter: + if not code_sandbox_runtime: + code_sandbox_runtime = os.getenv("CODE_SANDBOX_RUNTIME", "local") + if code_sandbox_runtime == "e2b": + envs = _get_e2b_env() + instance: CodeInterpreter = E2BCodeInterpreter( + timeout=_SESSION_TIMEOUT, remote_path=remote_path, envs=envs + ) + elif code_sandbox_runtime == "local": + instance = LocalCodeInterpreter( + timeout=_SESSION_TIMEOUT, remote_path=remote_path + ) + else: + raise ValueError( + f"Unsupported code sandbox runtime: {code_sandbox_runtime}. Supported runtimes: e2b, local" + ) + return instance + + +def _get_e2b_env() -> Union[Dict[str, str], None]: + openai_api_key = os.getenv("OPENAI_API_KEY", "") + anthropic_api_key = os.getenv("ANTHROPIC_API_KEY", "") + if openai_api_key or anthropic_api_key: + envs = {"WORKSPACE": os.getenv("WORKSPACE", "/home/user")} + if openai_api_key: + envs["OPENAI_API_KEY"] = openai_api_key + if anthropic_api_key: + envs["ANTHROPIC_API_KEY"] = anthropic_api_key + else: + envs = None + return envs + + def _parse_local_code_interpreter_outputs(outputs: List[Dict[str, Any]]) -> Execution: """Parse notebook cell outputs to Execution object. Output types: https://nbformat.readthedocs.io/en/latest/format_description.html#code-cell-outputs From 393bb4edb680870a4c338fba86e160079db00255 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 08:55:33 +0800 Subject: [PATCH 11/12] Revert "calling" This reverts commit c66a4639b39951cf7af1d8b8e6d6cf05754ca0d2. --- vision_agent/agent/vision_agent.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index ef49086a..ba6e1d64 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -21,7 +21,7 @@ use_extra_vision_agent_args, ) from vision_agent.utils import CodeInterpreterFactory -from vision_agent.utils.execute import CodeInterpreter, LocalCodeInterpreter, Execution +from vision_agent.utils.execute import CodeInterpreter, Execution logging.basicConfig(level=logging.INFO) _LOGGER = logging.getLogger(__name__) @@ -29,7 +29,6 @@ WORKSPACE.mkdir(parents=True, exist_ok=True) if str(WORKSPACE) != "": os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}" -_SESSION_TIMEOUT = 600 # 10 minutes class BoilerplateCode: @@ -196,7 +195,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent. @@ -207,17 +206,13 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python + code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. """ self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent self.max_iterations = 12 self.verbosity = verbosity - self.code_interpreter = ( - code_interpreter - if code_interpreter is not None - else LocalCodeInterpreter(timeout=_SESSION_TIMEOUT) - ) + self.code_sandbox_runtime = code_sandbox_runtime self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) @@ -289,7 +284,9 @@ def chat_with_code( # this is setting remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") - with self.code_interpreter as code_interpreter: + with CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_sandbox_runtime, + ) as code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) last_user_message = chat[-1] @@ -475,7 +472,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent using OpenAI LMMs. @@ -486,7 +483,7 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python + code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. """ agent = OpenAILMM(temperature=0.0, json_mode=True) if agent is None else agent @@ -494,7 +491,7 @@ def __init__( agent, verbosity, local_artifacts_path, - code_interpreter, + code_sandbox_runtime, callback_message, ) @@ -505,7 +502,7 @@ def __init__( agent: Optional[LMM] = None, verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, - code_interpreter: Optional[CodeInterpreter] = None, + code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent using Anthropic LMMs. @@ -516,7 +513,7 @@ def __init__( verbosity (int): The verbosity level of the agent. local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. - code_interpreter (Optional[CodeInterpreter]): The code interpreter to use, default to local python + code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. """ agent = AnthropicLMM(temperature=0.0) if agent is None else agent @@ -524,6 +521,6 @@ def __init__( agent, verbosity, local_artifacts_path, - code_interpreter, + code_sandbox_runtime, callback_message, ) From 02ee6e42fc32ee8f61dc301a93ecec4bdc689810 Mon Sep 17 00:00:00 2001 From: Zhichao Date: Fri, 11 Oct 2024 09:05:55 +0800 Subject: [PATCH 12/12] add an optional code_interpreter --- vision_agent/agent/vision_agent.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index ba6e1d64..1a38468f 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -197,6 +197,7 @@ def __init__( local_artifacts_path: Optional[Union[str, Path]] = None, code_sandbox_runtime: Optional[str] = None, callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, + code_interpreter: Optional[CodeInterpreter] = None, ) -> None: """Initialize the VisionAgent. @@ -207,12 +208,14 @@ def __init__( local_artifacts_path (Optional[Union[str, Path]]): The path to the local artifacts file. code_sandbox_runtime (Optional[str]): The code sandbox runtime to use. + code_interpreter (Optional[CodeInterpreter]): if not None, use this CodeInterpreter """ self.agent = AnthropicLMM(temperature=0.0) if agent is None else agent self.max_iterations = 12 self.verbosity = verbosity self.code_sandbox_runtime = code_sandbox_runtime + self.code_interpreter = code_interpreter self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) @@ -284,9 +287,14 @@ def chat_with_code( # this is setting remote artifacts path artifacts = Artifacts(WORKSPACE / "artifacts.pkl") - with CodeInterpreterFactory.new_instance( - code_sandbox_runtime=self.code_sandbox_runtime, - ) as code_interpreter: + code_interpreter = ( + self.code_interpreter + if self.code_interpreter is not None + else CodeInterpreterFactory.new_instance( + code_sandbox_runtime=self.code_sandbox_runtime, + ) + ) + with code_interpreter: orig_chat = copy.deepcopy(chat) int_chat = copy.deepcopy(chat) last_user_message = chat[-1]