Skip to content

Commit

Permalink
fix: handle edge cases gracefully in extract_frames_from_video (#269)
Browse files Browse the repository at this point in the history
* fix: handle edge cases gracefully in extract_frames_from_video

* Fix lint error
  • Loading branch information
humpydonkey authored Oct 14, 2024
1 parent b19cc3b commit f0a15ed
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 6 deletions.
Binary file removed tests/data/video/test.mp4
Binary file not shown.
59 changes: 56 additions & 3 deletions tests/unit/tools/test_video.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,63 @@
import tempfile
from typing import Optional

import cv2
import numpy as np

from vision_agent.utils.video import extract_frames_from_video


def test_extract_frames_from_video():
# TODO: consider generating a video on the fly instead
video_path = "tests/data/video/test.mp4"

video_path = _create_video(duration=2)
# there are 48 frames at 24 fps in this video file
res = extract_frames_from_video(video_path, fps=24)
assert len(res) == 48

res = extract_frames_from_video(video_path, fps=2)
assert len(res) == 4

res = extract_frames_from_video(video_path, fps=1)
assert len(res) == 2


def test_extract_frames_from_invalid_uri():
uri = "https://www.youtube.com/watch?v=HjGJvNRkuqY&ab_channel=TheSAHDStudio"
res = extract_frames_from_video(uri, 1.0)
assert len(res) == 0


def test_extract_frames_with_illegal_fps():
video_path = _create_video(duration=1)
res = extract_frames_from_video(video_path, -1.0)
assert len(res) == 1

res = extract_frames_from_video(video_path, None)
assert len(res) == 1

res = extract_frames_from_video(video_path, 0.0)
assert len(res) == 1


def test_extract_frames_with_input_video_has_no_fps():
video_path = _create_video(fps_video_prop=None)
res = extract_frames_from_video(video_path, 1.0)
assert len(res) == 0


def _create_video(
*, duration: int = 3, fps: int = 24, fps_video_prop: Optional[int] = 24
) -> str:
# Create a temporary file for the video
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
video_path = temp_video.name
# Set video properties
width, height = 640, 480
# Create a VideoWriter object without setting FPS
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(video_path, fourcc, fps_video_prop, (width, height))
# Generate and write random frames
for _ in range(duration * fps):
frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
out.write(frame)
out.release()
return video_path
26 changes: 23 additions & 3 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
_LOGGER = logging.getLogger(__name__)
# The maximum length of the clip to extract frames from, in seconds

_DEFAULT_VIDEO_FPS = 24
_DEFAULT_INPUT_FPS = 1.0


def play_video(video_base64: str) -> None:
"""Play a video file"""
Expand Down Expand Up @@ -51,7 +54,9 @@ def _resize_frame(frame: np.ndarray) -> np.ndarray:


def video_writer(
frames: List[np.ndarray], fps: float = 1.0, filename: Optional[str] = None
frames: List[np.ndarray],
fps: float = _DEFAULT_INPUT_FPS,
filename: Optional[str] = None,
) -> str:
if filename is None:
filename = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
Expand All @@ -78,7 +83,7 @@ def video_writer(


def frames_to_bytes(
frames: List[np.ndarray], fps: float = 1.0, file_ext: str = ".mp4"
frames: List[np.ndarray], fps: float = _DEFAULT_INPUT_FPS, file_ext: str = ".mp4"
) -> bytes:
r"""Convert a list of frames to a video file encoded into a byte string.
Expand All @@ -101,7 +106,7 @@ def frames_to_bytes(
# same file name and the time savings are very large.
@lru_cache(maxsize=8)
def extract_frames_from_video(
video_uri: str, fps: float = 1.0
video_uri: str, fps: float = _DEFAULT_INPUT_FPS
) -> List[Tuple[np.ndarray, float]]:
"""Extract frames from a video along with the timestamp in seconds.
Expand All @@ -118,6 +123,16 @@ def extract_frames_from_video(

cap = cv2.VideoCapture(video_uri)
orig_fps = cap.get(cv2.CAP_PROP_FPS)
if not orig_fps or orig_fps <= 0:
_LOGGER.warning(
f"Input video, {video_uri}, has no fps, using the default value {_DEFAULT_VIDEO_FPS}"
)
orig_fps = _DEFAULT_VIDEO_FPS
if not fps or fps <= 0:
_LOGGER.warning(
f"Input fps, {fps}, is illegal, using the default value: {_DEFAULT_INPUT_FPS}"
)
fps = _DEFAULT_INPUT_FPS
orig_frame_time = 1 / orig_fps
targ_frame_time = 1 / fps
frames: List[Tuple[np.ndarray, float]] = []
Expand All @@ -129,10 +144,15 @@ def extract_frames_from_video(
break

elapsed_time += orig_frame_time
# This is to prevent float point precision loss issue, which can cause
# the elapsed time to be slightly less than the target frame time, which
# causes the last frame to be skipped
elapsed_time = round(elapsed_time, 8)
if elapsed_time >= targ_frame_time:
frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps))
elapsed_time -= targ_frame_time

i += 1
cap.release()
_LOGGER.info(f"Extracted {len(frames)} frames from {video_uri}")
return frames

0 comments on commit f0a15ed

Please sign in to comment.