From 56764af4880201ee4f81d34c61e849355eb2777a Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Tue, 4 Jun 2024 19:19:43 -0700 Subject: [PATCH] Support video as an intermediate output and visualization (#112) * Support video as an intermediate output and local video visualization --- poetry.lock | 26 ++++++++--------- pyproject.toml | 2 +- vision_agent/agent/vision_agent.py | 4 +++ vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 47 ++++++++++++++++++++++++++++-- vision_agent/utils/execute.py | 24 +++++++++++++++ vision_agent/utils/video.py | 35 ++++++++++++++++++++++ 7 files changed, 122 insertions(+), 17 deletions(-) diff --git a/poetry.lock b/poetry.lock index af2d42bd..b81ac7e6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1787,29 +1787,29 @@ typing-extensions = ">=4.7,<5" datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] -name = "opencv-python-headless" -version = "4.9.0.80" +name = "opencv-python" +version = "4.10.0.82" description = "Wrapper package for OpenCV python bindings." optional = false python-versions = ">=3.6" files = [ - {file = "opencv-python-headless-4.9.0.80.tar.gz", hash = "sha256:71a4cd8cf7c37122901d8e81295db7fb188730e33a0e40039a4e59c1030b0958"}, - {file = "opencv_python_headless-4.9.0.80-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:2ea8a2edc4db87841991b2fbab55fc07b97ecb602e0f47d5d485bd75cee17c1a"}, - {file = "opencv_python_headless-4.9.0.80-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e0ee54e27be493e8f7850847edae3128e18b540dac1d7b2e4001b8944e11e1c6"}, - {file = "opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57ce2865e8fec431c6f97a81e9faaf23fa5be61011d0a75ccf47a3c0d65fa73d"}, - {file = "opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:976656362d68d9f40a5c66f83901430538002465f7db59142784f3893918f3df"}, - {file = "opencv_python_headless-4.9.0.80-cp37-abi3-win32.whl", hash = "sha256:11e3849d83e6651d4e7699aadda9ec7ed7c38957cbbcb99db074f2a2d2de9670"}, - {file = "opencv_python_headless-4.9.0.80-cp37-abi3-win_amd64.whl", hash = "sha256:a8056c2cb37cd65dfcdf4153ca16f7362afcf3a50d600d6bb69c660fc61ee29c"}, + {file = "opencv-python-4.10.0.82.tar.gz", hash = "sha256:dbc021eaa310c4145c47cd648cb973db69bb5780d6e635386cd53d3ea76bd2d5"}, + {file = "opencv_python-4.10.0.82-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:5f78652339957ec24b80a782becfb32f822d2008a865512121fad8c3ce233e9a"}, + {file = "opencv_python-4.10.0.82-cp37-abi3-macosx_12_0_x86_64.whl", hash = "sha256:e6be19a0615aa8c4e0d34e0c7b133e26e386f4b7e9b557b69479104ab2c133ec"}, + {file = "opencv_python-4.10.0.82-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b49e530f7fd86f671514b39ffacdf5d14ceb073bc79d0de46bbb6b0cad78eaf"}, + {file = "opencv_python-4.10.0.82-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:955c5ce8ac90c9e4636ad7f5c0d9c75b80abbe347182cfd09b0e3eec6e50472c"}, + {file = "opencv_python-4.10.0.82-cp37-abi3-win32.whl", hash = "sha256:ff54adc9e4daaf438e669664af08bec4a268c7b7356079338b8e4fae03810f2c"}, + {file = "opencv_python-4.10.0.82-cp37-abi3-win_amd64.whl", hash = "sha256:2e3c2557b176f1e528417520a52c0600a92c1bb1c359f3df8e6411ab4293f065"}, ] [package.dependencies] numpy = [ {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -1929,8 +1929,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -3561,4 +3561,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "f5e93a86c3148808cdb265dec3db5310d5d2d3240506bf75a2770bb41eadf1d6" +content-hash = "f83b2f518eb15325260c63eb90a84e54d70c85b047994e281659409bad3ef49d" diff --git a/pyproject.toml b/pyproject.toml index 6f21547b..cb7e48ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ pandas = "2.*" openai = "1.*" typing_extensions = "4.*" moviepy = "1.*" -opencv-python-headless = "4.*" +opencv-python = "4.*" tabulate = "^0.9.0" pydantic-settings = "^2.2.1" scipy = "1.13.*" diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 254949b9..013eb950 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -28,6 +28,7 @@ from vision_agent.utils.execute import CodeInterpreter from vision_agent.utils.image_utils import b64_to_pil from vision_agent.utils.sim import Sim +from vision_agent.utils.video import play_video logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) @@ -522,6 +523,9 @@ def chat_with_workflow( for res in execution_result.results: if res.png: b64_to_pil(res.png).show() + if res.mp4: + play_video(res.mp4) + return { "code": code, "test": test, diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 64a277d9..6136d937 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -22,6 +22,7 @@ overlay_segmentation_masks, save_image, save_json, + save_video_to_result, visual_prompt_counting, zero_shot_counting, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 3a5c0c51..be62787b 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -15,6 +15,7 @@ from vision_agent.tools.tool_utils import _send_inference_request from vision_agent.utils import extract_frames_from_video +from vision_agent.utils.execute import FileSerializer, MimeType from vision_agent.utils.image_utils import ( b64_to_pil, convert_to_b64, @@ -550,6 +551,29 @@ def save_image(image: np.ndarray) -> str: return f.name +def save_video_to_result(video_uri: str) -> None: + """'save_video_to_result' a utility function that saves a video into the result of the code execution (as an intermediate output). + This function is required to run if user wants to visualize the video generated by the code. + + Parameters: + video_uri (str): The URI to the video file. Currently only local file paths are supported. + + Example + ------- + >>> save_video_to_result("path/to/video.mp4") + """ + from IPython.display import display + + serializer = FileSerializer(video_uri) + display( + { + MimeType.VIDEO_MP4_B64: serializer.base64(), + MimeType.TEXT_PLAIN: str(serializer), + }, + raw=True, + ) + + def overlay_bounding_boxes( image: np.ndarray, bboxes: List[Dict[str, Any]] ) -> np.ndarray: @@ -570,6 +594,8 @@ def overlay_bounding_boxes( image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}], ) """ + from IPython.display import display + pil_image = Image.fromarray(image.astype(np.uint8)) if len(set([box["label"] for box in bboxes])) > len(COLORS): @@ -606,7 +632,10 @@ def overlay_bounding_boxes( text_box = draw.textbbox((box[0], box[1]), text=text, font=font) draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label]) draw.text((box[0], box[1]), text, fill="black", font=font) - return np.array(pil_image.convert("RGB")) + + pil_image = pil_image.convert("RGB") + display(pil_image) + return np.array(pil_image) def overlay_segmentation_masks( @@ -637,6 +666,8 @@ def overlay_segmentation_masks( }], ) """ + from IPython.display import display + pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA") if len(set([mask["label"] for mask in masks])) > len(COLORS): @@ -656,7 +687,10 @@ def overlay_segmentation_masks( np_mask[mask > 0, :] = color[label] + (255 * 0.5,) mask_img = Image.fromarray(np_mask.astype(np.uint8)) pil_image = Image.alpha_composite(pil_image, mask_img) - return np.array(pil_image.convert("RGB")) + + pil_image = pil_image.convert("RGB") + display(pil_image) + return np.array(pil_image) def overlay_heat_map( @@ -686,6 +720,8 @@ def overlay_heat_map( }, ) """ + from IPython.display import display + pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0: @@ -701,7 +737,10 @@ def overlay_heat_map( combined = Image.alpha_composite( pil_image.convert("RGBA"), overlay.resize(pil_image.size) ) - return np.array(combined.convert("RGB")) + + pil_image = combined.convert("RGB") + display(pil_image) + return np.array(pil_image) def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str: @@ -763,6 +802,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: save_json, load_image, save_image, + save_video_to_result, overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, @@ -775,6 +815,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame: save_json, load_image, save_image, + save_video_to_result, overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index 07fdf34f..39df5bde 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -1,5 +1,6 @@ import abc import atexit +import base64 import copy import logging import os @@ -45,12 +46,31 @@ class MimeType(str, Enum): IMAGE_SVG = "image/svg+xml" IMAGE_PNG = "image/png" IMAGE_JPEG = "image/jpeg" + VIDEO_MP4_B64 = "video/mp4/base64" APPLICATION_PDF = "application/pdf" TEXT_LATEX = "text/latex" APPLICATION_JSON = "application/json" APPLICATION_JAVASCRIPT = "application/javascript" +class FileSerializer: + """Adaptor class that allows IPython.display.display() to serialize a file to a base64 string representation.""" + + def __init__(self, file_uri: str): + self.video_uri = file_uri + assert os.path.isfile( + file_uri + ), f"Only support local files currently: {file_uri}" + assert Path(file_uri).exists(), f"File not found: {file_uri}" + + def __repr__(self) -> str: + return f"FileSerializer({self.video_uri})" + + def base64(self) -> str: + with open(self.video_uri, "rb") as file: + return base64.b64encode(file.read()).decode("utf-8") + + class Result: """ Represents the data to be displayed as a result of executing a cell in a Jupyter notebook. @@ -70,6 +90,7 @@ class Result: png: Optional[str] = None jpeg: Optional[str] = None pdf: Optional[str] = None + mp4: Optional[str] = None latex: Optional[str] = None json: Optional[Dict[str, Any]] = None javascript: Optional[str] = None @@ -93,6 +114,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]): self.png = data.pop(MimeType.IMAGE_PNG, None) self.jpeg = data.pop(MimeType.IMAGE_JPEG, None) self.pdf = data.pop(MimeType.APPLICATION_PDF, None) + self.mp4 = data.pop(MimeType.VIDEO_MP4_B64, None) self.latex = data.pop(MimeType.TEXT_LATEX, None) self.json = data.pop(MimeType.APPLICATION_JSON, None) self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None) @@ -190,6 +212,8 @@ def formats(self) -> Iterable[str]: formats.append("json") if self.javascript: formats.append("javascript") + if self.mp4: + formats.append("mp4") if self.extra: formats.extend(iter(self.extra)) return formats diff --git a/vision_agent/utils/video.py b/vision_agent/utils/video.py index 4eca66af..93c8c7fb 100644 --- a/vision_agent/utils/video.py +++ b/vision_agent/utils/video.py @@ -1,7 +1,9 @@ +import base64 import logging import math import os from concurrent.futures import ProcessPoolExecutor, as_completed +import tempfile from typing import List, Tuple, cast import cv2 @@ -14,6 +16,39 @@ _CLIP_LENGTH = 30.0 +def play_video(video_base64: str) -> None: + """Play a video file""" + video_data = base64.b64decode(video_base64) + with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video: + temp_video.write(video_data) + temp_video_path = temp_video.name + + cap = cv2.VideoCapture(temp_video_path) + if not cap.isOpened(): + _LOGGER.error("Error: Could not open video.") + return + + # Display the first frame and wait for any key press to start the video + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + cv2.imshow("Video Player", frame) + _LOGGER.info(f"Press any key to start playing the video: {temp_video_path}") + cv2.waitKey(0) # Wait for any key press + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + cv2.imshow("Video Player", frame) + # Press 'q' to exit the video + if cv2.waitKey(200) & 0xFF == ord("q"): + break + cap.release() + cv2.destroyAllWindows() + + def extract_frames_from_video( video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0 ) -> List[Tuple[np.ndarray, float]]: