From cff0756a28b1510cdef4ebb489a62677cc1c6a8d Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Wed, 5 Jun 2024 13:03:29 -0700 Subject: [PATCH] Minor improvements and fixes for code generation (#114) 1. Add save_video tool for better video serialization 2. Save image as an intermediate output when save_image is called 3. Add default imports from typing. 4. Keep the e2b sandbox alive for 5min when it runs a command. --- examples/custom_tools/run_custom_tool.py | 2 +- poetry.lock | 6 +-- vision_agent/agent/vision_agent.py | 6 ++- vision_agent/tools/__init__.py | 2 +- vision_agent/tools/tools.py | 61 ++++++++++++++---------- vision_agent/utils/execute.py | 3 ++ vision_agent/utils/video.py | 2 - 7 files changed, 49 insertions(+), 33 deletions(-) diff --git a/examples/custom_tools/run_custom_tool.py b/examples/custom_tools/run_custom_tool.py index 1e61ab6e..fb0cc5c0 100644 --- a/examples/custom_tools/run_custom_tool.py +++ b/examples/custom_tools/run_custom_tool.py @@ -1,7 +1,7 @@ import numpy as np +from template_match import template_matching_with_rotation import vision_agent as va -from template_match import template_matching_with_rotation from vision_agent.utils.image_utils import get_image_size, normalize_bbox diff --git a/poetry.lock b/poetry.lock index b81ac7e6..a76984bd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1805,11 +1805,11 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -1929,8 +1929,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 27880b7f..142cdf15 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -36,7 +36,11 @@ _LOGGER = logging.getLogger(__name__) _MAX_TABULATE_COL_WIDTH = 80 _CONSOLE = Console() -_DEFAULT_IMPORT = "\n".join(T.__new_tools__) +_DEFAULT_IMPORT = "\n".join(T.__new_tools__) + "\n".join( + [ + "from typing import *", + ] +) def get_diff(before: str, after: str) -> str: diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 6136d937..ccffa737 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -22,7 +22,7 @@ overlay_segmentation_masks, save_image, save_json, - save_video_to_result, + save_video, visual_prompt_counting, zero_shot_counting, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 11e219d0..a2c76156 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -5,12 +5,13 @@ import tempfile from importlib import resources from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import cv2 import numpy as np import pandas as pd import requests +from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageFont from vision_agent.tools.tool_utils import _send_inference_request @@ -545,24 +546,49 @@ def save_image(image: np.ndarray) -> str: >>> save_image(image) "/tmp/tmpabc123.png" """ + from IPython.display import display + pil_image = Image.fromarray(image.astype(np.uint8)) + display(pil_image) with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: - pil_image = Image.fromarray(image.astype(np.uint8)) pil_image.save(f, "PNG") return f.name -def save_video_to_result(video_uri: str) -> None: - """'save_video_to_result' a utility function that saves a video into the result of the code execution (as an intermediate output). - This function is required to run if user wants to visualize the video generated by the code. +def save_video( + frames: List[np.ndarray], output_video_path: Optional[str] = None, fps: float = 4 +) -> str: + """'save_video' is a utility function that saves a list of frames as a mp4 video file on disk. Parameters: - video_uri (str): The URI to the video file. Currently only local file paths are supported. + frames (list[np.ndarray]): A list of frames to save. + output_video_path (str): The path to save the video file. If not provided, a temporary file will be created. + fps (float): The number of frames composes a second in the video. + + Returns: + str: The path to the saved video file. Example ------- - >>> save_video_to_result("path/to/video.mp4") + >>> save_video(frames) + "/tmp/tmpvideo123.mp4" """ + if fps <= 0: + _LOGGER.warning(f"Invalid fps value: {fps}. Setting fps to 4 (default value).") + fps = 4 + with ImageSequenceClip(frames, fps=fps) as video: + if output_video_path: + f = open(output_video_path, "wb") + else: + f = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) # type: ignore + video.write_videofile(f.name, codec="libx264") + f.close() + _save_video_to_result(f.name) + return f.name + + +def _save_video_to_result(video_uri: str) -> None: + """Saves a video into the result of the code execution (as an intermediate output).""" from IPython.display import display serializer = FileSerializer(video_uri) @@ -595,8 +621,6 @@ def overlay_bounding_boxes( image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}], ) """ - from IPython.display import display - pil_image = Image.fromarray(image.astype(np.uint8)) if len(set([box["label"] for box in bboxes])) > len(COLORS): @@ -634,9 +658,6 @@ def overlay_bounding_boxes( text_box = draw.textbbox((box[0], box[1]), text=text, font=font) draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label]) draw.text((box[0], box[1]), text, fill="black", font=font) - - pil_image = pil_image.convert("RGB") - display(pil_image) return np.array(pil_image) @@ -668,8 +689,6 @@ def overlay_segmentation_masks( }], ) """ - from IPython.display import display - pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA") if len(set([mask["label"] for mask in masks])) > len(COLORS): @@ -690,9 +709,6 @@ def overlay_segmentation_masks( np_mask[mask > 0, :] = color[label] + (255 * 0.5,) mask_img = Image.fromarray(np_mask.astype(np.uint8)) pil_image = Image.alpha_composite(pil_image, mask_img) - - pil_image = pil_image.convert("RGB") - display(pil_image) return np.array(pil_image) @@ -723,8 +739,6 @@ def overlay_heat_map( }, ) """ - from IPython.display import display - pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0: @@ -740,10 +754,7 @@ def overlay_heat_map( combined = Image.alpha_composite( pil_image.convert("RGBA"), overlay.resize(pil_image.size) ) - - pil_image = combined.convert("RGB") - display(pil_image) - return np.array(pil_image) + return np.array(combined) def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str: @@ -805,7 +816,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: save_json, load_image, save_image, - save_video_to_result, + save_video, overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, @@ -818,7 +829,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: save_json, load_image, save_image, - save_video_to_result, + save_video, overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index ad9468e8..f38a123d 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -401,6 +401,8 @@ def download_file(self, file_path: str) -> Path: class E2BCodeInterpreter(CodeInterpreter): + KEEP_ALIVE_SEC: int = 300 + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) assert os.getenv("E2B_API_KEY"), "E2B_API_KEY environment variable must be set" @@ -432,6 +434,7 @@ def restart_kernel(self) -> None: retry=tenacity.retry_if_exception_type(TimeoutError), ) def exec_cell(self, code: str) -> Execution: + self.interpreter.keep_alive(E2BCodeInterpreter.KEEP_ALIVE_SEC) execution = self.interpreter.notebook.exec_cell(code, timeout=self.timeout) return Execution.from_e2b_execution(execution) diff --git a/vision_agent/utils/video.py b/vision_agent/utils/video.py index bd04ac8c..6dbb28de 100644 --- a/vision_agent/utils/video.py +++ b/vision_agent/utils/video.py @@ -31,7 +31,6 @@ def play_video(video_base64: str) -> None: # Display the first frame and wait for any key press to start the video ret, frame = cap.read() if ret: - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) cv2.imshow("Video Player", frame) _LOGGER.info(f"Press any key to start playing the video: {temp_video_path}") cv2.waitKey(0) # Wait for any key press @@ -40,7 +39,6 @@ def play_video(video_base64: str) -> None: ret, frame = cap.read() if not ret: break - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) cv2.imshow("Video Player", frame) # Press 'q' to exit the video if cv2.waitKey(200) & 0xFF == ord("q"):