Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle edge cases gracefully in extract_frames_from_video #269

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed tests/data/video/test.mp4
Binary file not shown.
59 changes: 56 additions & 3 deletions tests/unit/tools/test_video.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,63 @@
import tempfile
from typing import Optional

import cv2
import numpy as np

from vision_agent.utils.video import extract_frames_from_video


def test_extract_frames_from_video():
# TODO: consider generating a video on the fly instead
video_path = "tests/data/video/test.mp4"

video_path = _create_video(duration=2)
# there are 48 frames at 24 fps in this video file
res = extract_frames_from_video(video_path, fps=24)
assert len(res) == 48

res = extract_frames_from_video(video_path, fps=2)
assert len(res) == 4

res = extract_frames_from_video(video_path, fps=1)
assert len(res) == 2


def test_extract_frames_from_invalid_uri():
uri = "https://www.youtube.com/watch?v=HjGJvNRkuqY&ab_channel=TheSAHDStudio"
res = extract_frames_from_video(uri, 1.0)
assert len(res) == 0


def test_extract_frames_with_illegal_fps():
video_path = _create_video(duration=1)
res = extract_frames_from_video(video_path, -1.0)
assert len(res) == 1

res = extract_frames_from_video(video_path, None)
assert len(res) == 1

res = extract_frames_from_video(video_path, 0.0)
assert len(res) == 1


def test_extract_frames_with_input_video_has_no_fps():
video_path = _create_video(fps_video_prop=None)
res = extract_frames_from_video(video_path, 1.0)
assert len(res) == 0


def _create_video(
*, duration: int = 3, fps: int = 24, fps_video_prop: Optional[int] = 24
) -> str:
# Create a temporary file for the video
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
video_path = temp_video.name
# Set video properties
width, height = 640, 480
# Create a VideoWriter object without setting FPS
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(video_path, fourcc, fps_video_prop, (width, height))
# Generate and write random frames
for _ in range(duration * fps):
frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
out.write(frame)
out.release()
return video_path
26 changes: 23 additions & 3 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
_LOGGER = logging.getLogger(__name__)
# The maximum length of the clip to extract frames from, in seconds

_DEFAULT_VIDEO_FPS = 24
_DEFAULT_INPUT_FPS = 1.0


def play_video(video_base64: str) -> None:
"""Play a video file"""
Expand Down Expand Up @@ -51,7 +54,9 @@ def _resize_frame(frame: np.ndarray) -> np.ndarray:


def video_writer(
frames: List[np.ndarray], fps: float = 1.0, filename: Optional[str] = None
frames: List[np.ndarray],
fps: float = _DEFAULT_INPUT_FPS,
filename: Optional[str] = None,
) -> str:
if filename is None:
filename = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
Expand All @@ -78,7 +83,7 @@ def video_writer(


def frames_to_bytes(
frames: List[np.ndarray], fps: float = 1.0, file_ext: str = ".mp4"
frames: List[np.ndarray], fps: float = _DEFAULT_INPUT_FPS, file_ext: str = ".mp4"
) -> bytes:
r"""Convert a list of frames to a video file encoded into a byte string.

Expand All @@ -101,7 +106,7 @@ def frames_to_bytes(
# same file name and the time savings are very large.
@lru_cache(maxsize=8)
def extract_frames_from_video(
video_uri: str, fps: float = 1.0
video_uri: str, fps: float = _DEFAULT_INPUT_FPS
) -> List[Tuple[np.ndarray, float]]:
"""Extract frames from a video along with the timestamp in seconds.

Expand All @@ -118,6 +123,16 @@ def extract_frames_from_video(

cap = cv2.VideoCapture(video_uri)
orig_fps = cap.get(cv2.CAP_PROP_FPS)
if not orig_fps or orig_fps <= 0:
_LOGGER.warning(
f"Input video, {video_uri}, has no fps, using the default value {_DEFAULT_VIDEO_FPS}"
)
orig_fps = _DEFAULT_VIDEO_FPS
if not fps or fps <= 0:
_LOGGER.warning(
f"Input fps, {fps}, is illegal, using the default value: {_DEFAULT_INPUT_FPS}"
)
fps = _DEFAULT_INPUT_FPS
orig_frame_time = 1 / orig_fps
targ_frame_time = 1 / fps
frames: List[Tuple[np.ndarray, float]] = []
Expand All @@ -129,10 +144,15 @@ def extract_frames_from_video(
break

elapsed_time += orig_frame_time
# This is to prevent float point precision loss issue, which can cause
# the elapsed time to be slightly less than the target frame time, which
# causes the last frame to be skipped
elapsed_time = round(elapsed_time, 8)
if elapsed_time >= targ_frame_time:
frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps))
elapsed_time -= targ_frame_time

i += 1
cap.release()
_LOGGER.info(f"Extracted {len(frames)} frames from {video_uri}")
return frames
Loading