diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index f0a6bc29..7840ba40 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -12,7 +12,6 @@ import cv2 import numpy as np import requests -from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageEnhance, ImageFont from pillow_heif import register_heif_opener # type: ignore from pytube import YouTube # type: ignore @@ -35,7 +34,6 @@ ODResponseData, PromptTask, ) -from vision_agent.utils import extract_frames_from_video from vision_agent.utils.exceptions import FineTuneModelIsNotReady from vision_agent.utils.execute import FileSerializer, MimeType from vision_agent.utils.image_utils import ( @@ -44,13 +42,17 @@ convert_to_b64, denormalize_bbox, encode_image_bytes, - frames_to_bytes, get_image_size, normalize_bbox, numpy_to_bytes, rle_decode, rle_decode_array, ) +from vision_agent.utils.video import ( + extract_frames_from_video, + frames_to_bytes, + video_writer, +) register_heif_opener() @@ -1513,17 +1515,14 @@ def save_video( "/tmp/tmpvideo123.mp4" """ if fps <= 0: - _LOGGER.warning(f"Invalid fps value: {fps}. 
Setting fps to 4 (default value).") - fps = 4 - with ImageSequenceClip(frames, fps=fps) as video: - if output_video_path: - f = open(output_video_path, "wb") - else: - f = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) # type: ignore - video.write_videofile(f.name, codec="libx264") - f.close() - _save_video_to_result(f.name) - return f.name + raise ValueError(f"fps must be greater than 0 got {fps}") + + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(delete=False).name + + output_video_path = video_writer(frames, fps, output_video_path) + _save_video_to_result(output_video_path) + return output_video_path def _save_video_to_result(video_uri: str) -> None: diff --git a/vision_agent/utils/__init__.py b/vision_agent/utils/__init__.py index 9a5a271a..2810713a 100644 --- a/vision_agent/utils/__init__.py +++ b/vision_agent/utils/__init__.py @@ -7,4 +7,4 @@ Result, ) from .sim import AzureSim, OllamaSim, Sim, load_sim, merge_sim -from .video import extract_frames_from_video +from .video import extract_frames_from_video, video_writer diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py index 3612592d..d228b963 100644 --- a/vision_agent/utils/image_utils.py +++ b/vision_agent/utils/image_utils.py @@ -9,11 +9,10 @@ from typing import Dict, List, Tuple, Union import numpy as np -from moviepy.editor import ImageSequenceClip from PIL import Image, ImageDraw, ImageFont from PIL.Image import Image as ImageType -from vision_agent.utils import extract_frames_from_video +from vision_agent.utils import extract_frames_from_video, video_writer COLORS = [ (158, 218, 229), @@ -90,24 +89,6 @@ def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray: return binary_mask -def frames_to_bytes( - frames: List[np.ndarray], fps: float = 10, file_ext: str = "mp4" -) -> bytes: - r"""Convert a list of frames to a video file encoded into a byte string. 
def video_writer(
    frames: List[np.ndarray], fps: float = 1.0, filename: Optional[str] = None
) -> str:
    """Encode a sequence of RGB frames into a video file using the mp4v codec.

    Parameters:
        frames: the frames to encode. Each frame is converted from RGB to BGR
            before it is written. The video size is taken from the first frame.
        fps: the frames per second of the output video, must be greater than 0
        filename: the path to write the video to. When None, a temporary file
            with the ".mp4" suffix is created and left on disk for the caller.

    Returns:
        str: the path of the written video file

    Raises:
        ValueError: if frames is empty or fps is not greater than 0
        RuntimeError: if OpenCV could not open the output file for writing
    """
    # len() rather than truthiness so a stacked numpy array of frames also works.
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be greater than 0 got {fps}")

    if filename is None:
        # Only the reserved path is needed; close the handle immediately so it
        # is not leaked and so OpenCV can reopen the file on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
            filename = tmp_file.name

    height, width = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
    try:
        # A writer that failed to open drops every frame silently, so fail
        # loudly instead of returning the path of an empty or missing file.
        if not writer.isOpened():
            raise RuntimeError(f"could not open {filename} for writing video")
        for frame in frames:
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release so the container is finalized and the handle is freed,
        # even when a frame conversion raises.
        writer.release()
    return filename


def frames_to_bytes(
    frames: List[np.ndarray], fps: float = 10, file_ext: str = ".mp4"
) -> bytes:
    r"""Convert a list of frames to a video file encoded into a byte string.

    Parameters:
        frames: the list of frames
        fps: the frames per second of the video
        file_ext: the file extension of the video file, with or without the
            leading dot (".mp4" and "mp4" are treated the same)

    Returns:
        bytes: the content of the encoded video file

    Raises:
        ValueError: if frames is empty or fps is not greater than 0
        RuntimeError: if OpenCV could not open the temporary output file
    """
    import os

    # Accept the older dot-less form ("mp4") as well as the current ".mp4".
    suffix = file_ext if file_ext.startswith(".") else f".{file_ext}"

    # Write into a private temporary directory instead of reopening a
    # NamedTemporaryFile by name: the file is guaranteed to exist while it is
    # read back, and the directory is removed afterwards even on errors.
    with tempfile.TemporaryDirectory() as tmp_dir:
        video_path = video_writer(
            frames, fps, os.path.join(tmp_dir, f"video{suffix}")
        )
        with open(video_path, "rb") as f:
            return f.read()