Skip to content

Commit

Permalink
Update video reader (#226)
Browse files Browse the repository at this point in the history
* updated video reader

* updated dependencies

* organized video writing

* flake8

* fix type error

* fix type error

* fix test case

* add warning comment

* fix print
  • Loading branch information
dillonalaird authored Sep 6, 2024
1 parent 8ebb01e commit b082c01
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 386 deletions.
336 changes: 151 additions & 185 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ tqdm = ">=4.64.0,<5.0.0"
pandas = "2.*"
openai = "1.*"
typing_extensions = "4.*"
moviepy = "1.*"
opencv-python = "4.*"
tabulate = "^0.9.0"
pydantic-settings = "^2.2.1"
Expand All @@ -42,6 +41,7 @@ pillow-heif = "^0.16.0"
pytube = "15.0.0"
anthropic = "^0.31.0"
pydantic = "2.7.4"
eva-decord = "^0.6.1"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down Expand Up @@ -100,10 +100,8 @@ show_error_codes = true
ignore_missing_imports = true
module = [
"cv2.*",
"faiss.*",
"openai.*",
"sentence_transformers.*",
"moviepy.*",
"e2b_code_interpreter.*",
"e2b.*"
]
4 changes: 3 additions & 1 deletion tests/unit/tools/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
def test_extract_frames_from_video():
    """Check that the default sampling rate yields one frame per second of video."""
    # TODO: consider generating a video on the fly instead
    video_path = "tests/data/video/test.mp4"

    # there are 48 frames at 24 fps in this video file, i.e. 2 seconds of video,
    # so extracting at the default rate of 1 frame per second gives 2 frames
    res = extract_frames_from_video(video_path)
    assert len(res) == 2

    # each entry is a (frame, timestamp) tuple and timestamps must be ascending
    timestamps = [ts for _, ts in res]
    assert timestamps == sorted(timestamps)
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def pick_plan(

if verbosity == 2:
_print_code("Initial code and tests:", code)
_LOGGER.info(f"Initial code execution result:\n{tool_output.text()}")
_LOGGER.info(f"Initial code execution result:\n{tool_output_str}")

log_progress(
{
Expand Down
27 changes: 13 additions & 14 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import cv2
import numpy as np
import requests
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore
Expand All @@ -35,7 +34,6 @@
ODResponseData,
PromptTask,
)
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.exceptions import FineTuneModelIsNotReady
from vision_agent.utils.execute import FileSerializer, MimeType
from vision_agent.utils.image_utils import (
Expand All @@ -44,13 +42,17 @@
convert_to_b64,
denormalize_bbox,
encode_image_bytes,
frames_to_bytes,
get_image_size,
normalize_bbox,
numpy_to_bytes,
rle_decode,
rle_decode_array,
)
from vision_agent.utils.video import (
extract_frames_from_video,
frames_to_bytes,
video_writer,
)

register_heif_opener()

Expand Down Expand Up @@ -1513,17 +1515,14 @@ def save_video(
"/tmp/tmpvideo123.mp4"
"""
if fps <= 0:
_LOGGER.warning(f"Invalid fps value: {fps}. Setting fps to 4 (default value).")
fps = 4
with ImageSequenceClip(frames, fps=fps) as video:
if output_video_path:
f = open(output_video_path, "wb")
else:
f = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) # type: ignore
video.write_videofile(f.name, codec="libx264")
f.close()
_save_video_to_result(f.name)
return f.name
raise ValueError(f"fps must be greater than 0 got {fps}")

if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(delete=False).name

output_video_path = video_writer(frames, fps, output_video_path)
_save_video_to_result(output_video_path)
return output_video_path


def _save_video_to_result(video_uri: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
Result,
)
from .sim import AzureSim, OllamaSim, Sim, load_sim, merge_sim
from .video import extract_frames_from_video
from .video import extract_frames_from_video, video_writer
20 changes: 0 additions & 20 deletions vision_agent/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import base64
import io
import tempfile
from importlib import resources
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageFont
from PIL.Image import Image as ImageType

Expand Down Expand Up @@ -90,24 +88,6 @@ def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray:
return binary_mask


def frames_to_bytes(
frames: List[np.ndarray], fps: float = 10, file_ext: str = "mp4"
) -> bytes:
r"""Convert a list of frames to a video file encoded into a byte string.
Parameters:
frames: the list of frames
fps: the frames per second of the video
file_ext: the file extension of the video file
"""
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
clip = ImageSequenceClip(frames, fps=fps)
clip.write_videofile(temp_file.name + f".{file_ext}", fps=fps, codec="libx264")
with open(temp_file.name + f".{file_ext}", "rb") as f:
buffer_bytes = f.read()
return buffer_bytes


def b64_to_pil(b64_str: str) -> ImageType:
r"""Convert a base64 string to a PIL Image.
Expand Down
214 changes: 53 additions & 161 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import base64
import logging
import math
import os
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Tuple, cast
from functools import lru_cache
from typing import List, Optional, Tuple

import cv2
import numpy as np
from moviepy.video.io.VideoFileClip import VideoFileClip
from tqdm import tqdm
from decord import VideoReader # type: ignore

_LOGGER = logging.getLogger(__name__)
# The maximum length of the clip to extract frames from, in seconds
_CLIP_LENGTH = 30.0


def play_video(video_base64: str) -> None:
Expand Down Expand Up @@ -47,169 +43,65 @@ def play_video(video_base64: str) -> None:
cv2.destroyAllWindows()


def video_writer(
    frames: List[np.ndarray], fps: float = 1.0, filename: Optional[str] = None
) -> str:
    """Encode a sequence of RGB frames into a video file using the "mp4v" codec.

    Parameters:
        frames: the frames to write, in RGB channel order. The output resolution is
            taken from the first frame.
        fps: the frame rate of the output video, must be greater than 0
        filename: the path of the output file. If None, a temporary file with the
            ".mp4" suffix is created; it is not deleted automatically.

    Returns:
        The path of the written video file.

    Raises:
        ValueError: if `frames` is empty or `fps` is not greater than 0.
    """
    # len() instead of truthiness so a stacked numpy array of frames also works
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be greater than 0 got {fps}")

    if filename is None:
        # Only the name is needed; close the handle right away so the encoder is
        # the sole writer of the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
            filename = tmp_file.name

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # type: ignore
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
    try:
        for frame in frames:
            # frames are RGB, OpenCV expects BGR
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # always finalize the container, even if a frame fails to convert
        writer.release()
    return filename


def frames_to_bytes(
    frames: List[np.ndarray], fps: float = 10, file_ext: str = ".mp4"
) -> bytes:
    r"""Convert a list of frames to a video file encoded into a byte string.

    Parameters:
        frames: the list of frames
        fps: the frames per second of the video
        file_ext: the file extension of the video file, with or without the leading
            dot (".mp4" and "mp4" are treated the same)

    Returns:
        The raw bytes of the encoded video file.
    """
    import os  # local import so this block does not depend on the module import list

    # Without the dot the path has no recognisable extension for the writer.
    if file_ext and not file_ext.startswith("."):
        file_ext = f".{file_ext}"

    # Write into a private temporary directory instead of on top of a still-open
    # NamedTemporaryFile: re-opening such a file by name is not portable.
    with tempfile.TemporaryDirectory() as tmp_dir:
        video_path = os.path.join(tmp_dir, f"video{file_ext}")
        video_writer(frames, fps, video_path)
        with open(video_path, "rb") as f:
            return f.read()


# WARNING: this cache is a little dangerous: if the underlying video contents
# change but the filename remains the same, it will return the old file contents.
# For vision agent it's unlikely that the file contents change while the file
# name stays the same, and the time savings are very large.
@lru_cache(maxsize=8)
def extract_frames_from_video(
    video_uri: str, fps: float = 1.0
) -> List[Tuple[np.ndarray, float]]:
    """Extract frames from a video

    Results are memoized per (video_uri, fps), so repeated calls return the same
    list object; callers should not modify the returned list or frames in place.

    Parameters:
        video_uri (str): the path to the video file or a video file url
        fps (float): the frame rate per second to extract the frames. Must be
            greater than 0; values above the video's own frame rate are clamped
            to it.

    Returns:
        a list of tuples containing the extracted frame and the timestamp in seconds.
            E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
            from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
            the video. The frames are sorted by the timestamp in ascending order.

    Raises:
        ValueError: if fps is not greater than 0.
    """
    # Reject bad rates up front: 0 would divide by zero below and a negative value
    # would silently select no frames.
    if fps <= 0:
        raise ValueError(f"fps must be greater than 0 got {fps}")

    vr = VideoReader(video_uri)
    orig_fps = vr.get_avg_fps()
    # cannot sample more often than the source provides frames
    if fps > orig_fps:
        fps = orig_fps

    # number of source frames between two extracted frames
    step = orig_fps / fps
    frame_indices = [int(i * step) for i in range(int(len(vr) / step))]
    if not frame_indices:
        # nothing selected, skip the batch decode
        return []

    frames = vr.get_batch(frame_indices).asnumpy()
    return [
        (frames[pos], frame_idx / orig_fps)
        for pos, frame_idx in enumerate(frame_indices)
    ]

0 comments on commit b082c01

Please sign in to comment.