diff --git a/tests/data/video/test.mp4 b/tests/data/video/test.mp4 deleted file mode 100644 index 596eea38..00000000 Binary files a/tests/data/video/test.mp4 and /dev/null differ diff --git a/tests/unit/tools/test_video.py b/tests/unit/tools/test_video.py index 2ef1fe21..81b1699e 100644 --- a/tests/unit/tools/test_video.py +++ b/tests/unit/tools/test_video.py @@ -1,10 +1,63 @@ +import tempfile +from typing import Optional + +import cv2 +import numpy as np + from vision_agent.utils.video import extract_frames_from_video def test_extract_frames_from_video(): - # TODO: consider generating a video on the fly instead - video_path = "tests/data/video/test.mp4" - + video_path = _create_video(duration=2) # there are 48 frames at 24 fps in this video file res = extract_frames_from_video(video_path, fps=24) assert len(res) == 48 + + res = extract_frames_from_video(video_path, fps=2) + assert len(res) == 4 + + res = extract_frames_from_video(video_path, fps=1) + assert len(res) == 2 + + +def test_extract_frames_from_invalid_uri(): + uri = "https://www.youtube.com/watch?v=HjGJvNRkuqY&ab_channel=TheSAHDStudio" + res = extract_frames_from_video(uri, 1.0) + assert len(res) == 0 + + +def test_extract_frames_with_illegal_fps(): + video_path = _create_video(duration=1) + res = extract_frames_from_video(video_path, -1.0) + assert len(res) == 1 + + res = extract_frames_from_video(video_path, None) + assert len(res) == 1 + + res = extract_frames_from_video(video_path, 0.0) + assert len(res) == 1 + + +def test_extract_frames_with_input_video_has_no_fps(): + video_path = _create_video(fps_video_prop=None) + res = extract_frames_from_video(video_path, 1.0) + assert len(res) == 0 + + +def _create_video( + *, duration: int = 3, fps: int = 24, fps_video_prop: Optional[int] = 24 +) -> str: + # Create a temporary file for the video + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video: + video_path = temp_video.name + # Set video properties + width, height = 640, 480 + 
# Create a VideoWriter object; fps_video_prop (possibly None) is passed as its FPS + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(video_path, fourcc, fps_video_prop, (width, height)) + # Generate and write random frames + for _ in range(duration * fps): + frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + out.write(frame) + out.release() + return video_path diff --git a/vision_agent/utils/video.py b/vision_agent/utils/video.py index 0bb6fb18..6e8605c2 100644 --- a/vision_agent/utils/video.py +++ b/vision_agent/utils/video.py @@ -1,5 +1,6 @@ import base64 import logging +import math import tempfile from functools import lru_cache from typing import List, Optional, Tuple @@ -11,6 +12,9 @@ _LOGGER = logging.getLogger(__name__) # The maximum length of the clip to extract frames from, in seconds +_DEFAULT_VIDEO_FPS = 24 +_DEFAULT_INPUT_FPS = 1.0 + def play_video(video_base64: str) -> None: """Play a video file""" @@ -51,7 +55,9 @@ def _resize_frame(frame: np.ndarray) -> np.ndarray: def video_writer( - frames: List[np.ndarray], fps: float = 1.0, filename: Optional[str] = None + frames: List[np.ndarray], + fps: float = _DEFAULT_INPUT_FPS, + filename: Optional[str] = None, ) -> str: if filename is None: filename = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name @@ -78,7 +84,7 @@ def video_writer( def frames_to_bytes( - frames: List[np.ndarray], fps: float = 1.0, file_ext: str = ".mp4" + frames: List[np.ndarray], fps: float = _DEFAULT_INPUT_FPS, file_ext: str = ".mp4" ) -> bytes: r"""Convert a list of frames to a video file encoded into a byte string. @@ -101,7 +107,7 @@ def frames_to_bytes( # same file name and the time savings are very large. @lru_cache(maxsize=8) def extract_frames_from_video( - video_uri: str, fps: float = 1.0 + video_uri: str, fps: float = _DEFAULT_INPUT_FPS ) -> List[Tuple[np.ndarray, float]]: """Extract frames from a video along with the timestamp in seconds. 
@@ -118,6 +124,16 @@ def extract_frames_from_video( cap = cv2.VideoCapture(video_uri) orig_fps = cap.get(cv2.CAP_PROP_FPS) + if not orig_fps or orig_fps <= 0: + _LOGGER.warning( + f"Input video, {video_uri}, has no fps, using the default value {_DEFAULT_VIDEO_FPS}" + ) + orig_fps = _DEFAULT_VIDEO_FPS + if not fps or fps <= 0: + _LOGGER.warning( + f"Input fps, {fps}, is illegal, using the default value: {_DEFAULT_INPUT_FPS}" + ) + fps = _DEFAULT_INPUT_FPS orig_frame_time = 1 / orig_fps targ_frame_time = 1 / fps frames: List[Tuple[np.ndarray, float]] = [] @@ -129,10 +145,15 @@ def extract_frames_from_video( break elapsed_time += orig_frame_time + # Round to absorb accumulated floating-point error, which can leave + # the elapsed time slightly below the target frame time and + # cause the last frame to be skipped + elapsed_time = round(elapsed_time, 8) if elapsed_time >= targ_frame_time: frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps)) elapsed_time -= targ_frame_time i += 1 cap.release() + _LOGGER.info(f"Extracted {len(frames)} frames from {video_uri}") return frames