From 32e0156cb4b45d9d94e11b9f6a0bd407c4d5d7f1 Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Tue, 4 Jun 2024 11:20:57 -0700 Subject: [PATCH] Better code sandbox lifecycle management (#110) * Support local image visualization; Better code sandbox lifecycle management * Fix lint errors * Fix format * Print logs from execution result * Better format for execution outputs * More consistent format for execution outputs --- vision_agent/agent/vision_agent.py | 238 +++++++++++++++-------------- vision_agent/tools/tools.py | 12 +- vision_agent/utils/execute.py | 41 ++--- 3 files changed, 156 insertions(+), 135 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 647703d5..254949b9 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -25,12 +25,13 @@ from vision_agent.llm import LLM, OpenAILLM from vision_agent.lmm import LMM, OpenAILMM from vision_agent.utils import CodeInterpreterFactory, Execution +from vision_agent.utils.execute import CodeInterpreter +from vision_agent.utils.image_utils import b64_to_pil from vision_agent.utils.sim import Sim logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 -_EXECUTE = CodeInterpreterFactory.get_default_instance() _CONSOLE = Console() _DEFAULT_IMPORT = "\n".join(T.__new_tools__) @@ -122,6 +123,7 @@ def write_and_test_code( coder: LLM, tester: LLM, debugger: LLM, + code_interpreter: CodeInterpreter, log_progress: Callable[[Dict[str, Any]], None], verbosity: int = 0, max_retries: int = 3, @@ -158,7 +160,7 @@ def write_and_test_code( }, } ) - result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") + result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") log_progress( { "type": "code", @@ -173,7 +175,7 @@ def write_and_test_code( if verbosity == 2: _print_code("Initial code and tests:", code, test) _LOGGER.info( - f"Initial code execution result:\n{result.text(include_logs=False)}" + f"Initial code execution result:\n{result.text(include_logs=True)}" ) count = 0 @@ -210,7 +212,7 @@ def write_and_test_code( {"code": f"{code}\n{test}", "feedback": fixed_code_and_test["reflections"]} ) - result = _EXECUTE.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") + result = code_interpreter.exec_isolation(f"{_DEFAULT_IMPORT}\n{code}\n{test}") log_progress( { "type": "code", @@ -228,7 +230,7 @@ def write_and_test_code( ) _print_code("Code and test after attempted fix:", code, test) _LOGGER.info( - f"Code execution result after attempted fix: {result.text(include_logs=False)}" + f"Code execution result after attempted fix: {result.text(include_logs=True)}" ) count += 1 @@ -377,6 +379,7 @@ def chat_with_workflow( chat: List[Dict[str, str]], media: Optional[Union[str, Path]] = None, self_reflection: bool = False, + display_visualization: bool = False, ) -> Dict[str, Any]: """Chat with Vision Agent and return intermediate information regarding the task. @@ -385,6 +388,7 @@ def chat_with_workflow( [{"role": "user", "content": "describe your task here..."}]. media (Optional[Union[str, Path]]): The media file to be used in the task. self_reflection (bool): Whether to reflect on the task and debug the code. + show_visualization (bool): If True, it opens a new window locally to show the image(s) created by visualization code (if there is any). 
Returns: Dict[str, Any]: A dictionary containing the code, test, test result, plan, @@ -394,127 +398,137 @@ def chat_with_workflow( if not chat: raise ValueError("Chat cannot be empty.") - if media is not None: - media = _EXECUTE.upload_file(media) - for chat_i in chat: - if chat_i["role"] == "user": - chat_i["content"] += f" Image name {media}" - - # re-grab custom tools - global _DEFAULT_IMPORT - _DEFAULT_IMPORT = "\n".join(T.__new_tools__) - - code = "" - test = "" - working_memory: List[Dict[str, str]] = [] - results = {"code": "", "test": "", "plan": []} - plan = [] - success = False - retries = 0 - - while not success and retries < self.max_retries: - self.log_progress( - { - "type": "plans", - "status": "started", - } - ) - plan_i = write_plan( - chat, - T.TOOL_DESCRIPTIONS, - format_memory(working_memory), - self.planner, - media=[media] if media else None, - ) - plan_i_str = "\n-".join([e["instructions"] for e in plan_i]) - - self.log_progress( - { - "type": "plans", - "status": "completed", - "payload": plan_i, - } - ) - if self.verbosity >= 1: - - _LOGGER.info( - f""" -{tabulate(tabular_data=plan_i, headers="keys", tablefmt="mixed_grid", maxcolwidths=_MAX_TABULATE_COL_WIDTH)}""" - ) - - tool_info = retrieve_tools( - plan_i, - self.tool_recommender, - self.log_progress, - self.verbosity, - ) - results = write_and_test_code( - FULL_TASK.format(user_request=chat[0]["content"], subtasks=plan_i_str), - tool_info, - T.UTILITIES_DOCSTRING, - format_memory(working_memory), - self.coder, - self.tester, - self.debugger, - self.log_progress, - verbosity=self.verbosity, - input_media=media, - ) - success = cast(bool, results["success"]) - code = cast(str, results["code"]) - test = cast(str, results["test"]) - working_memory.extend(results["working_memory"]) # type: ignore - plan.append({"code": code, "test": test, "plan": plan_i}) - - if self_reflection: + # NOTE: each chat should have a dedicated code interpreter instance to avoid concurrency issues + with CodeInterpreterFactory.new_instance() as code_interpreter: + if media is not None: + media = code_interpreter.upload_file(media) + for chat_i in chat: + if chat_i["role"] == "user": + chat_i["content"] += f" Image name {media}" + + # re-grab custom tools + global _DEFAULT_IMPORT + _DEFAULT_IMPORT = "\n".join(T.__new_tools__) + + code = "" + test = "" + working_memory: List[Dict[str, str]] = [] + results = {"code": "", "test": "", "plan": []} + plan = [] + success = False + retries = 0 + + while not success and retries < self.max_retries: self.log_progress( { - "type": "self_reflection", + "type": "plans", "status": "started", } ) - reflection = reflect( + plan_i = write_plan( chat, - FULL_TASK.format( - user_request=chat[0]["content"], subtasks=plan_i_str - ), - code, + T.TOOL_DESCRIPTIONS, + format_memory(working_memory), self.planner, + media=[media] if media else None, ) - if self.verbosity > 0: - _LOGGER.info(f"Reflection: {reflection}") - feedback = cast(str, reflection["feedback"]) - success = cast(bool, reflection["success"]) + plan_i_str = "\n-".join([e["instructions"] for e in plan_i]) + self.log_progress( { - "type": "self_reflection", - "status": "completed" if success else "failed", - "payload": reflection, + "type": "plans", + "status": "completed", + "payload": plan_i, } ) - working_memory.append({"code": f"{code}\n{test}", "feedback": feedback}) - - retries += 1 + if self.verbosity >= 1: + _LOGGER.info( + f"\n{tabulate(tabular_data=plan_i, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + ) + 
+ tool_info = retrieve_tools( + plan_i, + self.tool_recommender, + self.log_progress, + self.verbosity, + ) + results = write_and_test_code( + task=FULL_TASK.format( + user_request=chat[0]["content"], subtasks=plan_i_str + ), + tool_info=tool_info, + tool_utils=T.UTILITIES_DOCSTRING, + working_memory=format_memory(working_memory), + coder=self.coder, + tester=self.tester, + debugger=self.debugger, + code_interpreter=code_interpreter, + log_progress=self.log_progress, + verbosity=self.verbosity, + input_media=media, + ) + success = cast(bool, results["success"]) + code = cast(str, results["code"]) + test = cast(str, results["test"]) + working_memory.extend(results["working_memory"]) # type: ignore + plan.append({"code": code, "test": test, "plan": plan_i}) + + if self_reflection: + self.log_progress( + { + "type": "self_reflection", + "status": "started", + } + ) + reflection = reflect( + chat, + FULL_TASK.format( + user_request=chat[0]["content"], subtasks=plan_i_str + ), + code, + self.planner, + ) + if self.verbosity > 0: + _LOGGER.info(f"Reflection: {reflection}") + feedback = cast(str, reflection["feedback"]) + success = cast(bool, reflection["success"]) + self.log_progress( + { + "type": "self_reflection", + "status": "completed" if success else "failed", + "payload": reflection, + } + ) + working_memory.append( + {"code": f"{code}\n{test}", "feedback": feedback} + ) + + retries += 1 + + execution_result = cast(Execution, results["test_result"]) + self.log_progress( + { + "type": "final_code", + "status": "completed" if success else "failed", + "payload": { + "code": code, + "test": test, + "result": execution_result.to_json(), + }, + } + ) - self.log_progress( - { - "type": "final_code", - "status": "completed" if success else "failed", - "payload": { - "code": code, - "test": test, - "result": cast(Execution, results["test_result"]).to_json(), - }, + if display_visualization: + for res in execution_result.results: + if res.png: + b64_to_pil(res.png).show() + return { + "code": code, + "test": test, + "test_result": execution_result, + "plan": plan, + "working_memory": working_memory, } - ) - - return { - "code": code, - "test": test, - "test_result": results["test_result"], - "plan": plan, - "working_memory": working_memory, - } def log_progress(self, data: Dict[str, Any]) -> None: if self.report_progress_callback is not None: diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 7baa02b6..9f44e536 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -524,7 +524,7 @@ def save_image(image: np.ndarray) -> str: def overlay_bounding_boxes( image: np.ndarray, bboxes: List[Dict[str, Any]] ) -> np.ndarray: - """'display_bounding_boxes' is a utility function that displays bounding boxes on + """'overlay_bounding_boxes' is a utility function that displays bounding boxes on an image. Parameters: @@ -537,7 +537,7 @@ def overlay_bounding_boxes( Example ------- - >>> image_with_bboxes = display_bounding_boxes( + >>> image_with_bboxes = overlay_bounding_boxes( image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}], ) """ @@ -583,7 +583,7 @@ def overlay_bounding_boxes( def overlay_segmentation_masks( image: np.ndarray, masks: List[Dict[str, Any]] ) -> np.ndarray: - """'display_segmentation_masks' is a utility function that displays segmentation + """'overlay_segmentation_masks' is a utility function that displays segmentation masks. 
Parameters: @@ -595,7 +595,7 @@ def overlay_segmentation_masks( Example ------- - >>> image_with_masks = display_segmentation_masks( + >>> image_with_masks = overlay_segmentation_masks( image, [{ 'score': 0.99, @@ -633,7 +633,7 @@ def overlay_segmentation_masks( def overlay_heat_map( image: np.ndarray, heat_map: Dict[str, Any], alpha: float = 0.8 ) -> np.ndarray: - """'display_heat_map' is a utility function that displays a heat map on an image. + """'overlay_heat_map' is a utility function that displays a heat map on an image. Parameters: image (np.ndarray): The image to display the heat map on. @@ -646,7 +646,7 @@ def overlay_heat_map( Example ------- - >>> image_with_heat_map = display_heat_map( + >>> image_with_heat_map = overlay_heat_map( image, { 'heat_map': array([[0, 0, 0, ..., 0, 0, 0], diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index cadde3e8..07fdf34f 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -8,6 +8,7 @@ import sys import tempfile import traceback +import warnings from enum import Enum from io import IOBase from pathlib import Path @@ -218,7 +219,7 @@ def __str__(self) -> str: stdout_str = "\n".join(self.stdout) stderr_str = "\n".join(self.stderr) return _remove_escape_and_color_codes( - f"stdout:\n{stdout_str}\nstderr:\n{stderr_str}" + f"----- stdout -----\n{stdout_str}\n----- stderr -----\n{stderr_str}" ) @@ -263,21 +264,19 @@ def text(self, include_logs: bool = True) -> str: """ Returns the text representation of this object, i.e. including the main result or the error traceback, optionally along with the logs (stdout, stderr). """ - prefix = ( - "\n".join(self.logs.stdout) + "\n".join(self.logs.stderr) - if include_logs - else "" - ) + prefix = str(self.logs) if include_logs else "" if self.error: - return prefix + "\n" + self.error.traceback - return next( + return prefix + "\n----- Error -----\n" + self.error.traceback + + result_str = [ ( - prefix + "\n" + (res.text or "") - for res in self.results + f"----- Final output -----\n{res.text}" if res.is_main_result - ), - prefix, - ) + else f"----- Intermediate output-----\n{res.text}" + ) + for res in self.results + ] + return prefix + "\n" + "\n".join(result_str) @property def success(self) -> bool: @@ -404,7 +403,7 @@ def restart_kernel(self) -> None: self.interpreter.notebook.restart_kernel() def exec_cell(self, code: str) -> Execution: - execution = self.interpreter.notebook.exec_cell(code) + execution = self.interpreter.notebook.exec_cell(code, timeout=self.timeout) return Execution.from_e2b_execution(execution) def upload_file(self, file: Union[str, Path, IO]) -> str: @@ -508,16 +507,24 @@ class CodeInterpreterFactory: @staticmethod def get_default_instance() -> CodeInterpreter: + warnings.warn( + "Use new_instance() instead for production usage, get_default_instance() is for testing and will be removed in the future." 
+ ) inst_map = CodeInterpreterFactory._instance_map instance = inst_map.get(CodeInterpreterFactory._default_key) if instance: return instance + instance = CodeInterpreterFactory.new_instance() + inst_map[CodeInterpreterFactory._default_key] = instance + return instance + + @staticmethod + def new_instance() -> CodeInterpreter: if os.getenv("CODE_SANDBOX_RUNTIME") == "e2b": - instance = E2BCodeInterpreter(timeout=600) - atexit.register(instance.close) + instance: CodeInterpreter = E2BCodeInterpreter(timeout=600) else: instance = LocalCodeInterpreter(timeout=600) - inst_map[CodeInterpreterFactory._default_key] = instance + atexit.register(instance.close) return instance
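
The core lifecycle change above is that chat_with_workflow no longer leans on the module-level _EXECUTE singleton: each call builds its own interpreter through CodeInterpreterFactory.new_instance() and scopes it with a `with` block, while get_default_instance() is kept for tests, now warns, and obtains its cached copy from new_instance(). A minimal sketch of the same pattern for code that drives the sandbox directly, assuming only the APIs visible in this diff (new_instance, upload_file, exec_isolation, text, and the context-manager protocol); the file path and the executed snippet are made up for illustration:

import os

from vision_agent.utils import CodeInterpreterFactory

# Only "e2b" selects the remote sandbox; any other value (or an unset variable)
# falls back to the local interpreter. Both paths now register close() with
# atexit inside new_instance().
os.environ["CODE_SANDBOX_RUNTIME"] = "local"  # hypothetical value, anything but "e2b" means local

with CodeInterpreterFactory.new_instance() as sandbox:
    # upload_file copies the file into the sandbox and returns the name the
    # generated code should reference.
    image_name = sandbox.upload_file("examples/dinosaur.jpg")  # hypothetical path
    result = sandbox.exec_isolation(f"print('processing {image_name}')")
    print(result.text(include_logs=True))
# Leaving the with-block tears the interpreter down, so nothing is shared
# between concurrent chats.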
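
display_visualization closes the loop for local runs: every result the sandbox produced that carries a base64-encoded PNG is decoded with b64_to_pil and opened in a window once the workflow finishes. A hedged usage sketch follows; the agent class name and its constructor are assumed from the surrounding file rather than shown in this diff, and the media path is hypothetical:

from vision_agent.agent import VisionAgent  # class name assumed; the diff only shows its methods
from vision_agent.utils.image_utils import b64_to_pil

agent = VisionAgent()
outputs = agent.chat_with_workflow(
    [{"role": "user", "content": "Count the dinosaurs in the image"}],
    media="dinosaur.jpg",        # hypothetical input image
    display_visualization=True,  # opens any images the generated code produced
)

# The same data lives on the returned Execution, so a caller can render the
# visualizations itself instead of relying on the pop-up windows:
for res in outputs["test_result"].results:
    if res.png:  # base64-encoded PNG emitted by the sandbox
        b64_to_pil(res.png).show()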
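
The reworked Logs.__str__ and Execution.text make the sandbox output self-describing: stdout, stderr, an error traceback, and main versus intermediate results each get their own labeled divider, and the two call sites in write_and_test_code now log with include_logs=True so those sections actually reach the agent log. Continuing the sketch above, the printed content below is illustrative; only the section headers come from this patch:

# Given any Execution, e.g. the one returned under outputs["test_result"] above:
execution = outputs["test_result"]
print(execution.text(include_logs=True))
# Roughly:
#
#   ----- stdout -----
#   3 dinosaurs detected
#   ----- stderr -----
#
#   ----- Final output -----
#   3
#
# With include_logs=False only the "Final output" / "Intermediate output"
# sections remain, and a failed run returns "----- Error -----" followed by
# the traceback instead.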