Skip to content

Commit

Permalink
organized video writing
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 5, 2024
1 parent 92e2667 commit bac5299
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 36 deletions.
27 changes: 13 additions & 14 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import cv2
import numpy as np
import requests
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore
Expand All @@ -35,7 +34,6 @@
ODResponseData,
PromptTask,
)
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.exceptions import FineTuneModelIsNotReady
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
Expand All @@ -44,13 +42,17 @@
convert_to_b64,
denormalize_bbox,
encode_image_bytes,
frames_to_bytes,
get_image_size,
normalize_bbox,
numpy_to_bytes,
rle_decode,
rle_decode_array,
)
from vision_agent.utils.video import (
extract_frames_from_video,
frames_to_bytes,
video_writer,
)

register_heif_opener()

Expand Down Expand Up @@ -1513,17 +1515,14 @@ def save_video(
"/tmp/tmpvideo123.mp4"
"""
if fps <= 0:
_LOGGER.warning(f"Invalid fps value: {fps}. Setting fps to 4 (default value).")
fps = 4
with ImageSequenceClip(frames, fps=fps) as video:
if output_video_path:
f = open(output_video_path, "wb")
else:
f = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) # type: ignore
video.write_videofile(f.name, codec="libx264")
f.close()
_save_video_to_result(f.name)
return f.name
raise ValueError(f"fps must be greater than 0 got {fps}")

if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(delete=False).name

output_video_path = video_writer(frames, fps, output_video_path)
_save_video_to_result(output_video_path)
return output_video_path


def _save_video_to_result(video_uri: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
Result,
)
from .sim import AzureSim, OllamaSim, Sim, load_sim, merge_sim
from .video import extract_frames_from_video
from .video import extract_frames_from_video, video_writer
21 changes: 1 addition & 20 deletions vision_agent/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from typing import Dict, List, Tuple, Union

import numpy as np
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageFont
from PIL.Image import Image as ImageType

from vision_agent.utils import extract_frames_from_video
from vision_agent.utils import extract_frames_from_video, video_writer

COLORS = [
(158, 218, 229),
Expand Down Expand Up @@ -90,24 +89,6 @@ def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray:
return binary_mask


def frames_to_bytes(
frames: List[np.ndarray], fps: float = 10, file_ext: str = "mp4"
) -> bytes:
r"""Convert a list of frames to a video file encoded into a byte string.
Parameters:
frames: the list of frames
fps: the frames per second of the video
file_ext: the file extension of the video file
"""
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
clip = ImageSequenceClip(frames, fps=fps)
clip.write_videofile(temp_file.name + f".{file_ext}", fps=fps, codec="libx264")
with open(temp_file.name + f".{file_ext}", "rb") as f:
buffer_bytes = f.read()
return buffer_bytes


def b64_to_pil(b64_str: str) -> ImageType:
r"""Convert a base64 string to a PIL Image.
Expand Down
35 changes: 34 additions & 1 deletion vision_agent/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import tempfile
from functools import lru_cache
from typing import List, Tuple
from typing import List, Optional, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -43,6 +43,39 @@ def play_video(video_base64: str) -> None:
cv2.destroyAllWindows()


def video_writer(
frames: List[np.ndarray], fps: float = 1.0, filename: Optional[str] = None
) -> str:
if filename is None:
filename = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
height, width = frames[0].shape[:2]
writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
for frame in frames:
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
writer.release()
return filename


def frames_to_bytes(
frames: List[np.ndarray], fps: float = 10, file_ext: str = ".mp4"
) -> bytes:
r"""Convert a list of frames to a video file encoded into a byte string.
Parameters:
frames: the list of frames
fps: the frames per second of the video
file_ext: the file extension of the video file
"""
with tempfile.NamedTemporaryFile(delete=True, suffix=file_ext) as temp_file:
video_writer(frames, fps, temp_file.name)

with open(temp_file.name, "rb") as f:
buffer_bytes = f.read()
return buffer_bytes


@lru_cache(maxsize=8)
def extract_frames_from_video(
video_uri: str, fps: float = 1.0
Expand Down

0 comments on commit bac5299

Please sign in to comment.