From 63746e8039baff61b02fd4b693ec90bfedca21fa Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sun, 18 Aug 2024 20:04:11 -0700 Subject: [PATCH] added ixc 2.5 for video --- tests/integ/test_tools.py | 31 ++++++++- vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 115 +++++++++++++++++++-------------- 3 files changed, 97 insertions(+), 50 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 0a52a0b2..afa9dcb4 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -1,5 +1,6 @@ import numpy as np import skimage as ski +from PIL import Image from vision_agent.tools import ( blip_image_caption, @@ -10,15 +11,17 @@ dpt_hybrid_midas, florence2_image_caption, florence2_object_detection, - florence2_roberta_vqa, florence2_ocr, + florence2_roberta_vqa, florence2_sam2_image, - ixc25_image_vqa, + florence2_sam2_video, generate_pose_image, generate_soft_edge_image, git_vqa_v2, grounding_dino, grounding_sam, + ixc25_image_vqa, + ixc25_video_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, ocr, @@ -101,6 +104,19 @@ def test_florence2_sam2_image(): assert len([res["mask"] for res in result]) == 25 +def test_florence2_sam2_video(): + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = florence2_sam2_video( + prompt="coin", + frames=frames, + ) + assert len(result) == 10 + assert len([res["label"] for res in result[0]]) == 25 + assert len([res["mask"] for res in result[0]]) == 25 + + def test_segmentation(): img = ski.data.coins() result = detr_segmentation( @@ -197,6 +213,17 @@ def test_ixc25_image_vqa() -> None: assert "cat" in result.strip() +def test_ixc25_video_vqa() -> None: + frames = [ + np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10) + ] + result = ixc25_video_vqa( + prompt="What animal is in this video?", + frames=frames, + ) + assert "cat" in result.strip() + + def test_ocr() -> None: img = ski.data.page() result = ocr( diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 3369499d..d5da4ad8 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -29,6 +29,7 @@ grounding_dino, grounding_sam, ixc25_image_vqa, + ixc25_video_vqa, load_image, loca_visual_prompt_counting, loca_zero_shot_counting, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 8c2e8d7e..e21dd95c 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -355,54 +355,6 @@ def florence2_sam2_video( return return_data -def extract_frames( - video_uri: Union[str, Path], fps: float = 0.5 -) -> List[Tuple[np.ndarray, float]]: - """'extract_frames' extracts frames from a video which can be a file path or youtube - link, returns a list of tuples (frame, timestamp), where timestamp is the relative - time in seconds where the frame was captured. The frame is a numpy array. - - Parameters: - video_uri (Union[str, Path]): The path to the video file or youtube link - fps (float, optional): The frame rate per second to extract the frames. Defaults - to 0.5. - - Returns: - List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame - as a numpy array and the timestamp in seconds. - - Example - ------- - >>> extract_frames("path/to/video.mp4") - [(frame1, 0.0), (frame2, 0.5), ...] - """ - - if str(video_uri).startswith( - ( - "http://www.youtube.com/", - "https://www.youtube.com/", - "http://youtu.be/", - "https://youtu.be/", - ) - ): - with tempfile.TemporaryDirectory() as temp_dir: - yt = YouTube(str(video_uri)) - # Download the highest resolution video - video = ( - yt.streams.filter(progressive=True, file_extension="mp4") - .order_by("resolution") - .desc() - .first() - ) - if not video: - raise Exception("No suitable video stream found") - video_file_path = video.download(output_path=temp_dir) - - return extract_frames_from_video(video_file_path, fps) - - return extract_frames_from_video(str(video_uri), fps) - - def ocr(image: np.ndarray) -> List[Dict[str, Any]]: """'ocr' extracts text from an image. It returns a list of detected text, bounding boxes with normalized coordinates, and confidence scores. The results are sorted @@ -572,6 +524,25 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str: return data["answer"] +def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str: + """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos + including regular videos or videos of documents or presentations. It returns text + as an answer to the question. + + + """ + buffer_bytes = frames_to_bytes(frames) + files = [("video", buffer_bytes)] + payload = { + "prompt": prompt, + "function_name": "ixc25_video_vqa", + } + data: Dict[str, Any] = send_inference_request( + payload, "internlm-xcomposer2", files=files, v2=True + ) + return data["answer"] + + def git_vqa_v2(prompt: str, image: np.ndarray) -> str: """'git_vqa_v2' is a tool that can answer questions about the visual contents of an image given a question and an image. It returns an answer to the @@ -1158,6 +1129,54 @@ def closest_box_distance( # Utility and visualization functions +def extract_frames( + video_uri: Union[str, Path], fps: float = 0.5 +) -> List[Tuple[np.ndarray, float]]: + """'extract_frames' extracts frames from a video which can be a file path or youtube + link, returns a list of tuples (frame, timestamp), where timestamp is the relative + time in seconds where the frame was captured. The frame is a numpy array. + + Parameters: + video_uri (Union[str, Path]): The path to the video file or youtube link + fps (float, optional): The frame rate per second to extract the frames. Defaults + to 0.5. + + Returns: + List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame + as a numpy array and the timestamp in seconds. + + Example + ------- + >>> extract_frames("path/to/video.mp4") + [(frame1, 0.0), (frame2, 0.5), ...] + """ + + if str(video_uri).startswith( + ( + "http://www.youtube.com/", + "https://www.youtube.com/", + "http://youtu.be/", + "https://youtu.be/", + ) + ): + with tempfile.TemporaryDirectory() as temp_dir: + yt = YouTube(str(video_uri)) + # Download the highest resolution video + video = ( + yt.streams.filter(progressive=True, file_extension="mp4") + .order_by("resolution") + .desc() + .first() + ) + if not video: + raise Exception("No suitable video stream found") + video_file_path = video.download(output_path=temp_dir) + + return extract_frames_from_video(video_file_path, fps) + + return extract_frames_from_video(str(video_uri), fps) + + def save_json(data: Any, file_path: str) -> None: """'save_json' is a utility function that saves data as a JSON file. It is helpful for saving data that contains NumPy arrays which are not JSON serializable.