Skip to content

Commit

Permalink
added ixc 2.5 for video
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Aug 21, 2024
1 parent e24762e commit 07832e3
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 50 deletions.
31 changes: 29 additions & 2 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import skimage as ski
from PIL import Image

from vision_agent.tools import (
blip_image_caption,
Expand All @@ -10,15 +11,17 @@
dpt_hybrid_midas,
florence2_image_caption,
florence2_object_detection,
florence2_roberta_vqa,
florence2_ocr,
florence2_roberta_vqa,
florence2_sam2_image,
ixc25_image_vqa,
florence2_sam2_video,
generate_pose_image,
generate_soft_edge_image,
git_vqa_v2,
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
Expand Down Expand Up @@ -101,6 +104,19 @@ def test_florence2_sam2_image():
assert len([res["mask"] for res in result]) == 25


def test_florence2_sam2_video():
    """Run 'florence2_sam2_video' on a 10-frame clip of the coins sample image.

    Checks that one result list comes back per frame and that the first frame
    yields 25 entries, each carrying both a "label" and a "mask" key.
    """
    num_frames = 10
    # Convert the grayscale sample to RGB once; every frame is still a freshly
    # allocated array built from that single PIL image.
    coins_rgb = Image.fromarray(ski.data.coins()).convert("RGB")
    clip = [np.array(coins_rgb) for _ in range(num_frames)]

    result = florence2_sam2_video(prompt="coin", frames=clip)

    assert len(result) == num_frames
    first_frame = result[0]
    # Indexing by key (rather than just len(first_frame)) also verifies that the
    # key is present on every detection.
    assert len([det["label"] for det in first_frame]) == 25
    assert len([det["mask"] for det in first_frame]) == 25


def test_segmentation():
img = ski.data.coins()
result = detr_segmentation(
Expand Down Expand Up @@ -197,6 +213,17 @@ def test_ixc25_image_vqa() -> None:
assert "cat" in result.strip()


def test_ixc25_video_vqa() -> None:
    """Ask 'ixc25_video_vqa' about a 10-frame clip of the cat sample image.

    Passes when the stripped answer text contains the substring "cat".
    """
    # Convert the sample to RGB once; every frame is still a freshly allocated
    # array built from that single PIL image.
    cat_rgb = Image.fromarray(ski.data.cat()).convert("RGB")
    clip = [np.array(cat_rgb) for _ in range(10)]

    answer = ixc25_video_vqa(prompt="What animal is in this video?", frames=clip)

    assert "cat" in answer.strip()


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
Expand Down
115 changes: 67 additions & 48 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,54 +355,6 @@ def florence2_sam2_video(
return return_data


def extract_frames(
    video_uri: Union[str, Path], fps: float = 0.5
) -> List[Tuple[np.ndarray, float]]:
    """'extract_frames' extracts frames from a video which can be a file path or youtube
    link, returns a list of tuples (frame, timestamp), where timestamp is the relative
    time in seconds where the frame was captured. The frame is a numpy array.

    Parameters:
        video_uri (Union[str, Path]): The path to the video file or youtube link
        fps (float, optional): The frame rate per second to extract the frames. Defaults
            to 0.5.

    Returns:
        List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
            as a numpy array and the timestamp in seconds.

    Raises:
        Exception: If the link is a youtube link and no progressive mp4 stream is
            available for it.

    Example
    -------
        >>> extract_frames("path/to/video.mp4")
        [(frame1, 0.0), (frame2, 0.5), ...]
    """
    uri = str(video_uri)
    youtube_prefixes = (
        "http://www.youtube.com/",
        "https://www.youtube.com/",
        "http://youtu.be/",
        "https://youtu.be/",
    )

    # Anything that is not a recognised youtube link is passed straight through
    # to extract_frames_from_video.
    if not uri.startswith(youtube_prefixes):
        return extract_frames_from_video(uri, fps)

    # The download lives in a temporary directory, so the frames are extracted
    # before the context manager exits and removes it.
    with tempfile.TemporaryDirectory() as download_dir:
        mp4_streams = YouTube(uri).streams.filter(
            progressive=True, file_extension="mp4"
        )
        # Download the highest resolution video
        best_stream = mp4_streams.order_by("resolution").desc().first()
        if not best_stream:
            raise Exception("No suitable video stream found")
        local_path = best_stream.download(output_path=download_dir)

        return extract_frames_from_video(local_path, fps)


def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
"""'ocr' extracts text from an image. It returns a list of detected text, bounding
boxes with normalized coordinates, and confidence scores. The results are sorted
Expand Down Expand Up @@ -580,6 +532,25 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str:
return data["answer"]


def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
    """'ixc25_video_vqa' is a tool that can answer any questions about arbitrary videos
    including regular videos or videos of documents or presentations. It returns text
    as an answer to the question.

    Parameters:
        prompt (str): The question about the video
        frames (List[np.ndarray]): The frames of the video to ask the question about

    Returns:
        str: A string which is the answer to the given prompt.

    Example
    -------
        >>> ixc25_video_vqa('What animal is in this video?', frames)
        'cat'
    """
    # The frames are serialised into a single buffer and uploaded as the
    # "video" file part of the request.
    video_bytes = frames_to_bytes(frames)
    response: Dict[str, Any] = send_inference_request(
        {"prompt": prompt, "function_name": "ixc25_video_vqa"},
        "internlm-xcomposer2",
        files=[("video", video_bytes)],
        v2=True,
    )
    return response["answer"]


def git_vqa_v2(prompt: str, image: np.ndarray) -> str:
"""'git_vqa_v2' is a tool that can answer questions about the visual
contents of an image given a question and an image. It returns an answer to the
Expand Down Expand Up @@ -1166,6 +1137,54 @@ def closest_box_distance(
# Utility and visualization functions


def extract_frames(
    video_uri: Union[str, Path], fps: float = 0.5
) -> List[Tuple[np.ndarray, float]]:
    """'extract_frames' extracts frames from a video which can be a file path or youtube
    link, returns a list of tuples (frame, timestamp), where timestamp is the relative
    time in seconds where the frame was captured. The frame is a numpy array.

    Parameters:
        video_uri (Union[str, Path]): The path to the video file or youtube link
        fps (float, optional): The frame rate per second to extract the frames. Defaults
            to 0.5.

    Returns:
        List[Tuple[np.ndarray, float]]: A list of tuples containing the extracted frame
            as a numpy array and the timestamp in seconds.

    Raises:
        Exception: If the link is a youtube link and no progressive mp4 stream is
            available for it.

    Example
    -------
        >>> extract_frames("path/to/video.mp4")
        [(frame1, 0.0), (frame2, 0.5), ...]
    """
    uri = str(video_uri)
    youtube_prefixes = (
        "http://www.youtube.com/",
        "https://www.youtube.com/",
        "http://youtu.be/",
        "https://youtu.be/",
    )

    # Anything that is not a recognised youtube link is passed straight through
    # to extract_frames_from_video.
    if not uri.startswith(youtube_prefixes):
        return extract_frames_from_video(uri, fps)

    # The download lives in a temporary directory, so the frames are extracted
    # before the context manager exits and removes it.
    with tempfile.TemporaryDirectory() as download_dir:
        mp4_streams = YouTube(uri).streams.filter(
            progressive=True, file_extension="mp4"
        )
        # Download the highest resolution video
        best_stream = mp4_streams.order_by("resolution").desc().first()
        if not best_stream:
            raise Exception("No suitable video stream found")
        local_path = best_stream.download(output_path=download_dir)

        return extract_frames_from_video(local_path, fps)


def save_json(data: Any, file_path: str) -> None:
"""'save_json' is a utility function that saves data as a JSON file. It is helpful
for saving data that contains NumPy arrays which are not JSON serializable.
Expand Down

0 comments on commit 07832e3

Please sign in to comment.