Skip to content

Commit

Permalink
Support video as an intermediate output and visualization (#112)
Browse files Browse the repository at this point in the history
* Support video as an intermediate output and local video visualization
  • Loading branch information
humpydonkey authored Jun 5, 2024
1 parent f3d5bcf commit 56764af
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 17 deletions.
26 changes: 13 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pandas = "2.*"
openai = "1.*"
typing_extensions = "4.*"
moviepy = "1.*"
opencv-python-headless = "4.*"
opencv-python = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"
scipy = "1.13.*"
Expand Down
4 changes: 4 additions & 0 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
from vision_agent.utils.sim import Sim
from vision_agent.utils.video import play_video

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -522,6 +523,9 @@ def chat_with_workflow(
for res in execution_result.results:
if res.png:
b64_to_pil(res.png).show()
if res.mp4:
play_video(res.mp4)

return {
"code": code,
"test": test,
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
overlay_segmentation_masks,
save_image,
save_json,
save_video_to_result,
visual_prompt_counting,
zero_shot_counting,
)
Expand Down
47 changes: 44 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
b64_to_pil,
convert_to_b64,
Expand Down Expand Up @@ -550,6 +551,29 @@ def save_image(image: np.ndarray) -> str:
return f.name


def save_video_to_result(video_uri: str) -> None:
    """'save_video_to_result' is a utility function that saves a video into the result
    of the code execution (as an intermediate output). This function must be called if
    the user wants to visualize a video generated by the code.

    Parameters:
        video_uri (str): The URI to the video file. Currently only local file paths
            are supported.

    Example
    -------
        >>> save_video_to_result("path/to/video.mp4")
    """
    # Imported lazily, matching the other display helpers in this module.
    from IPython.display import display

    video_file = FileSerializer(video_uri)
    # The bundle is keyed by mime type: the base64 payload under the custom
    # video key plus a plain-text fallback naming the serialized file.
    mime_bundle = {
        MimeType.VIDEO_MP4_B64: video_file.base64(),
        MimeType.TEXT_PLAIN: str(video_file),
    }
    # raw=True: the dict is handed over as an already-formatted mime bundle
    # rather than as an object for IPython to format.
    display(mime_bundle, raw=True)


def overlay_bounding_boxes(
image: np.ndarray, bboxes: List[Dict[str, Any]]
) -> np.ndarray:
Expand All @@ -570,6 +594,8 @@ def overlay_bounding_boxes(
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8))

if len(set([box["label"] for box in bboxes])) > len(COLORS):
Expand Down Expand Up @@ -606,7 +632,10 @@ def overlay_bounding_boxes(
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)
return np.array(pil_image.convert("RGB"))

pil_image = pil_image.convert("RGB")
display(pil_image)
return np.array(pil_image)


def overlay_segmentation_masks(
Expand Down Expand Up @@ -637,6 +666,8 @@ def overlay_segmentation_masks(
}],
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")

if len(set([mask["label"] for mask in masks])) > len(COLORS):
Expand All @@ -656,7 +687,10 @@ def overlay_segmentation_masks(
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)
return np.array(pil_image.convert("RGB"))

pil_image = pil_image.convert("RGB")
display(pil_image)
return np.array(pil_image)


def overlay_heat_map(
Expand Down Expand Up @@ -686,6 +720,8 @@ def overlay_heat_map(
},
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0:
Expand All @@ -701,7 +737,10 @@ def overlay_heat_map(
combined = Image.alpha_composite(
pil_image.convert("RGBA"), overlay.resize(pil_image.size)
)
return np.array(combined.convert("RGB"))

pil_image = combined.convert("RGB")
display(pil_image)
return np.array(pil_image)


def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
Expand Down Expand Up @@ -763,6 +802,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
save_json,
load_image,
save_image,
save_video_to_result,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
Expand All @@ -775,6 +815,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
save_json,
load_image,
save_image,
save_video_to_result,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
Expand Down
24 changes: 24 additions & 0 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import atexit
import base64
import copy
import logging
import os
Expand Down Expand Up @@ -45,12 +46,31 @@ class MimeType(str, Enum):
IMAGE_SVG = "image/svg+xml"
IMAGE_PNG = "image/png"
IMAGE_JPEG = "image/jpeg"
VIDEO_MP4_B64 = "video/mp4/base64"
APPLICATION_PDF = "application/pdf"
TEXT_LATEX = "text/latex"
APPLICATION_JSON = "application/json"
APPLICATION_JAVASCRIPT = "application/javascript"


class FileSerializer:
    """Adaptor class that allows IPython.display.display() to serialize a file to a base64 string representation.

    Attributes are documented inline; the attribute is still called
    ``video_uri`` for backward compatibility even though any local file path
    is accepted.
    """

    def __init__(self, file_uri: str):
        """Remember the path of an existing local file.

        Parameters:
            file_uri (str): Path to a regular file on the local filesystem.

        Raises:
            FileNotFoundError: If ``file_uri`` does not point to an existing
                regular local file (e.g. it is missing, a directory, or a
                remote URI).
        """
        # Explicit raise instead of `assert`: asserts are stripped when Python
        # runs with -O, which would silently disable this validation.
        # os.path.isfile() already implies the path exists, so a separate
        # existence check is unnecessary.
        if not os.path.isfile(file_uri):
            raise FileNotFoundError(f"Only support local files currently: {file_uri}")
        # Path of the file that base64() reads.
        self.video_uri = file_uri

    def __repr__(self) -> str:
        """Return a short textual description naming the wrapped file."""
        return f"FileSerializer({self.video_uri})"

    def base64(self) -> str:
        """Read the whole file and return its content as a base64 UTF-8 string."""
        with open(self.video_uri, "rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")


class Result:
"""
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
Expand All @@ -70,6 +90,7 @@ class Result:
png: Optional[str] = None
jpeg: Optional[str] = None
pdf: Optional[str] = None
mp4: Optional[str] = None
latex: Optional[str] = None
json: Optional[Dict[str, Any]] = None
javascript: Optional[str] = None
Expand All @@ -93,6 +114,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]):
self.png = data.pop(MimeType.IMAGE_PNG, None)
self.jpeg = data.pop(MimeType.IMAGE_JPEG, None)
self.pdf = data.pop(MimeType.APPLICATION_PDF, None)
self.mp4 = data.pop(MimeType.VIDEO_MP4_B64, None)
self.latex = data.pop(MimeType.TEXT_LATEX, None)
self.json = data.pop(MimeType.APPLICATION_JSON, None)
self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None)
Expand Down Expand Up @@ -190,6 +212,8 @@ def formats(self) -> Iterable[str]:
formats.append("json")
if self.javascript:
formats.append("javascript")
if self.mp4:
formats.append("mp4")
if self.extra:
formats.extend(iter(self.extra))
return formats
Expand Down
35 changes: 35 additions & 0 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import logging
import math
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
import tempfile
from typing import List, Tuple, cast

import cv2
Expand All @@ -14,6 +16,39 @@
_CLIP_LENGTH = 30.0


def play_video(video_base64: str) -> None:
    """Play a base64-encoded video in a local OpenCV window.

    The decoded bytes are written to a temporary ``.mp4`` file, which is
    opened with ``cv2.VideoCapture``. The first frame is shown until any key
    is pressed; the remaining frames are then shown one by one, and pressing
    ``q`` stops playback early. The temporary file is removed before the
    function returns.

    Parameters:
        video_base64 (str): The video file content encoded as base64.
    """
    video_data = base64.b64decode(video_base64)
    # delete=False so the file can be closed and then reopened by path below;
    # it is removed explicitly in the outer `finally`.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
        temp_video.write(video_data)
        temp_video_path = temp_video.name

    try:
        cap = cv2.VideoCapture(temp_video_path)
        if not cap.isOpened():
            _LOGGER.error("Error: Could not open video.")
            return

        try:
            # Display the first frame and wait for any key press to start the video.
            # Frames are shown exactly as read: VideoCapture yields BGR and
            # cv2.imshow expects BGR, so converting to RGB would swap the
            # red and blue channels on screen.
            ret, frame = cap.read()
            if ret:
                cv2.imshow("Video Player", frame)
                _LOGGER.info(
                    f"Press any key to start playing the video: {temp_video_path}"
                )
                cv2.waitKey(0)  # Wait for any key press

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                cv2.imshow("Video Player", frame)
                # Press 'q' to exit the video
                if cv2.waitKey(200) & 0xFF == ord("q"):
                    break
        finally:
            # Release the capture and close the window even if playback fails.
            cap.release()
            cv2.destroyAllWindows()
    finally:
        # Do not leak one temporary video file per call.
        try:
            os.remove(temp_video_path)
        except OSError:
            _LOGGER.warning(f"Could not remove temporary video file: {temp_video_path}")


def extract_frames_from_video(
video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0
) -> List[Tuple[np.ndarray, float]]:
Expand Down

0 comments on commit 56764af

Please sign in to comment.