Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support video as an intermediate output and visualization #112

Merged
merged 6 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pandas = "2.*"
openai = "1.*"
typing_extensions = "4.*"
moviepy = "1.*"
opencv-python-headless = "4.*"
opencv-python = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"
scipy = "1.13.*"
Expand Down
4 changes: 4 additions & 0 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from vision_agent.utils.execute import CodeInterpreter
from vision_agent.utils.image_utils import b64_to_pil
from vision_agent.utils.sim import Sim
from vision_agent.utils.video import play_video

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -522,6 +523,9 @@ def chat_with_workflow(
for res in execution_result.results:
if res.png:
b64_to_pil(res.png).show()
if res.mp4:
play_video(res.mp4)

return {
"code": code,
"test": test,
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
overlay_segmentation_masks,
save_image,
save_json,
save_video_to_result,
visual_prompt_counting,
zero_shot_counting,
)
Expand Down
47 changes: 44 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
b64_to_pil,
convert_to_b64,
Expand Down Expand Up @@ -550,6 +551,29 @@ def save_image(image: np.ndarray) -> str:
return f.name


def save_video_to_result(video_uri: str) -> None:
    """'save_video_to_result' is a utility function that saves a video into the result
    of the code execution (as an intermediate output). This function must be called
    if the user wants to visualize a video generated by the code.

    Parameters:
        video_uri (str): The URI to the video file. Currently only local file paths
            are supported.

    Example
    -------
        >>> save_video_to_result("path/to/video.mp4")
    """
    # Imported lazily, inside the function, as the original does.
    from IPython.display import display

    video_file = FileSerializer(video_uri)

    # Publish two representations of the same file: the base64-encoded video
    # payload and a short plain-text description of the serializer.
    mime_bundle = {
        MimeType.VIDEO_MP4_B64: video_file.base64(),
        MimeType.TEXT_PLAIN: str(video_file),
    }

    # raw=True: the dict is handed to display() as an already MIME-keyed bundle
    # (see the IPython display documentation for the `raw` parameter).
    display(mime_bundle, raw=True)


def overlay_bounding_boxes(
image: np.ndarray, bboxes: List[Dict[str, Any]]
) -> np.ndarray:
Expand All @@ -570,6 +594,8 @@ def overlay_bounding_boxes(
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8))

if len(set([box["label"] for box in bboxes])) > len(COLORS):
Expand Down Expand Up @@ -606,7 +632,10 @@ def overlay_bounding_boxes(
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)
return np.array(pil_image.convert("RGB"))

pil_image = pil_image.convert("RGB")
display(pil_image)
return np.array(pil_image)


def overlay_segmentation_masks(
Expand Down Expand Up @@ -637,6 +666,8 @@ def overlay_segmentation_masks(
}],
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")

if len(set([mask["label"] for mask in masks])) > len(COLORS):
Expand All @@ -656,7 +687,10 @@ def overlay_segmentation_masks(
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)
return np.array(pil_image.convert("RGB"))

pil_image = pil_image.convert("RGB")
display(pil_image)
return np.array(pil_image)


def overlay_heat_map(
Expand Down Expand Up @@ -686,6 +720,8 @@ def overlay_heat_map(
},
)
"""
from IPython.display import display

pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0:
Expand All @@ -701,7 +737,10 @@ def overlay_heat_map(
combined = Image.alpha_composite(
pil_image.convert("RGBA"), overlay.resize(pil_image.size)
)
return np.array(combined.convert("RGB"))

pil_image = combined.convert("RGB")
display(pil_image)
return np.array(pil_image)


def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
Expand Down Expand Up @@ -763,6 +802,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
save_json,
load_image,
save_image,
save_video_to_result,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
Expand All @@ -775,6 +815,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
save_json,
load_image,
save_image,
save_video_to_result,
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
Expand Down
24 changes: 24 additions & 0 deletions vision_agent/utils/execute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import atexit
import base64
import copy
import logging
import os
Expand Down Expand Up @@ -45,12 +46,31 @@ class MimeType(str, Enum):
IMAGE_SVG = "image/svg+xml"
IMAGE_PNG = "image/png"
IMAGE_JPEG = "image/jpeg"
VIDEO_MP4_B64 = "video/mp4/base64"
APPLICATION_PDF = "application/pdf"
TEXT_LATEX = "text/latex"
APPLICATION_JSON = "application/json"
APPLICATION_JAVASCRIPT = "application/javascript"


class FileSerializer:
    """Adaptor class that allows IPython.display.display() to serialize a file to a base64 string representation.

    The path is validated once at construction time; the file content is read
    lazily, each time ``base64()`` is called.

    Attributes:
        video_uri (str): The path given to the constructor, stored unchanged.
    """

    def __init__(self, file_uri: str):
        """Validate that ``file_uri`` points at an existing regular file.

        Parameters:
            file_uri (str): Path to a local file.

        Raises:
            AssertionError: If the path does not exist, or exists but is not a
                regular file (for example a directory).
        """
        # Explicit raises instead of `assert` so the validation still runs under
        # `python -O`. AssertionError is kept so existing callers that catch it
        # continue to work.
        # Existence is checked first so that a missing path is reported as
        # "File not found" rather than as an unsupported location.
        if not os.path.exists(file_uri):
            raise AssertionError(f"File not found: {file_uri}")
        if not os.path.isfile(file_uri):
            raise AssertionError(f"Only support local files currently: {file_uri}")
        self.video_uri = file_uri

    def __repr__(self) -> str:
        """Return a short description containing the stored path."""
        return f"FileSerializer({self.video_uri})"

    def base64(self) -> str:
        """Read the whole file and return its content as a base64 text string."""
        with open(self.video_uri, "rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")


class Result:
"""
Represents the data to be displayed as a result of executing a cell in a Jupyter notebook.
Expand All @@ -70,6 +90,7 @@ class Result:
png: Optional[str] = None
jpeg: Optional[str] = None
pdf: Optional[str] = None
mp4: Optional[str] = None
latex: Optional[str] = None
json: Optional[Dict[str, Any]] = None
javascript: Optional[str] = None
Expand All @@ -93,6 +114,7 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]):
self.png = data.pop(MimeType.IMAGE_PNG, None)
self.jpeg = data.pop(MimeType.IMAGE_JPEG, None)
self.pdf = data.pop(MimeType.APPLICATION_PDF, None)
self.mp4 = data.pop(MimeType.VIDEO_MP4_B64, None)
self.latex = data.pop(MimeType.TEXT_LATEX, None)
self.json = data.pop(MimeType.APPLICATION_JSON, None)
self.javascript = data.pop(MimeType.APPLICATION_JAVASCRIPT, None)
Expand Down Expand Up @@ -190,6 +212,8 @@ def formats(self) -> Iterable[str]:
formats.append("json")
if self.javascript:
formats.append("javascript")
if self.mp4:
formats.append("mp4")
if self.extra:
formats.extend(iter(self.extra))
return formats
Expand Down
35 changes: 35 additions & 0 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import logging
import math
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
import tempfile
from typing import List, Tuple, cast

import cv2
Expand All @@ -14,6 +16,39 @@
_CLIP_LENGTH = 30.0


def _show_frames(cap: "cv2.VideoCapture", source_path: str) -> None:
    """Display frames from an opened capture until it ends or 'q' is pressed."""
    try:
        # Display the first frame and wait for any key press to start the video
        ret, frame = cap.read()
        if ret:
            # Frames from VideoCapture are BGR, which is also what cv2.imshow
            # expects, so they are shown as-is (no colour conversion).
            cv2.imshow("Video Player", frame)
            _LOGGER.info(f"Press any key to start playing the video: {source_path}")
            cv2.waitKey(0)  # Wait for any key press

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.imshow("Video Player", frame)
            # Press 'q' to exit the video
            if cv2.waitKey(200) & 0xFF == ord("q"):
                break
    finally:
        # Close the player window even if reading or displaying a frame fails.
        cv2.destroyAllWindows()


def play_video(video_base64: str) -> None:
    """Play a base64-encoded MP4 video in an OpenCV window.

    The decoded bytes are written to a temporary ``.mp4`` file, which is opened
    with ``cv2.VideoCapture``. The first frame is shown until any key is
    pressed; the remaining frames are then shown with a 200 ms wait between
    them, and pressing 'q' stops playback early. The capture is released and
    the temporary file is removed before returning, including on failure.

    Parameters:
        video_base64 (str): The video file content, base64 encoded.
    """
    video_data = base64.b64decode(video_base64)
    # delete=False so the file can be closed here and reopened by path from
    # OpenCV; it is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_video:
        temp_video.write(video_data)
        temp_video_path = temp_video.name

    try:
        cap = cv2.VideoCapture(temp_video_path)
        try:
            if not cap.isOpened():
                _LOGGER.error("Error: Could not open video.")
                return
            _show_frames(cap, temp_video_path)
        finally:
            cap.release()
    finally:
        # Best-effort cleanup: a failure to delete must not mask playback errors.
        try:
            os.remove(temp_video_path)
        except OSError as err:
            _LOGGER.warning(
                "Could not remove temporary video file %s: %s", temp_video_path, err
            )


def extract_frames_from_video(
video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0
) -> List[Tuple[np.ndarray, float]]:
Expand Down
Loading