diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 17b9d3f0..38c38150 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -13,6 +13,8 @@ import nbformat from dotenv import load_dotenv +from e2b_code_interpreter import Execution as E2BExecution +from e2b_code_interpreter import Result as E2BResult from nbclient import NotebookClient from nbclient import __version__ as nbclient_version from nbclient.exceptions import CellTimeoutError, DeadKernelError @@ -198,6 +200,23 @@ def formats(self) -> Iterable[str]: formats.extend(iter(self.extra)) return formats + @staticmethod + def from_e2b_result(result: E2BResult) -> "Result": + """ + Creates a Result object from an E2BResult object. + """ + data = { + MimeType.TEXT_PLAIN.value: result.text, + MimeType.IMAGE_PNG.value: result.png, + MimeType.APPLICATION_JSON.value: result.json, + } + for k, v in result.extra.items(): + data[k] = v + return Result( + is_main_result=result.is_main_result, + data=data, + ) + class Logs(BaseModel): """Data printed to stdout and stderr during execution, usually by print statements, @@ -338,6 +357,26 @@ def from_exception(exec: Exception, traceback_raw: List[str]) -> "Execution": ) ) + @staticmethod + def from_e2b_execution(exec: E2BExecution) -> "Execution": + """Creates an Execution object from an E2BResult object.""" + return Execution( + results=[Result.from_e2b_result(res) for res in exec.results], + logs=Logs(stdout=exec.logs.stdout, stderr=exec.logs.stderr), + error=( + Error( + name=exec.error.name, + value=_remove_escape_and_color_codes(exec.error.value), + traceback_raw=[ + _remove_escape_and_color_codes(line) + for line in exec.error.traceback.split("\n") + ], + ) + if exec.error + else None + ), + ) + class CodeInterpreter(abc.ABC): """Code interpreter interface."""