Skip to content

Commit

Permalink
Add Temporal Localization & Fix Video Reading (#233)
Browse files Browse the repository at this point in the history
* fixed issue with video reader

* added temporal localization

* fix video reader

* remove decord
  • Loading branch information
dillonalaird authored Sep 11, 2024
1 parent 28c7787 commit 7db696d
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 68 deletions.
92 changes: 35 additions & 57 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ pillow-heif = "^0.16.0"
pytube = "15.0.0"
anthropic = "^0.31.0"
pydantic = "2.7.4"
eva-decord = "^0.6.1"
av = "^11.0.0"

[tool.poetry.group.dev.dependencies]
Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_temporal_localization,
ixc25_video_vqa,
load_image,
loca_visual_prompt_counting,
Expand Down
38 changes: 38 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,44 @@ def ixc25_video_vqa(prompt: str, frames: List[np.ndarray]) -> str:
return cast(str, data["answer"])


def ixc25_temporal_localization(prompt: str, frames: List[np.ndarray]) -> List[bool]:
    """'ixc25_temporal_localization' uses ixc25_video_vqa to temporally segment a video
    given a prompt that can be either an object or a phrase. It returns a list of
    boolean values indicating whether the object or phrase is present in the
    corresponding frame.

    Parameters:
        prompt (str): The question about the video
        frames (List[np.ndarray]): The reference frames used for the question

    Returns:
        List[bool]: A list of boolean values, one per input frame, indicating
            whether the object or phrase is present in that frame.

    Example
    -------
        >>> output = ixc25_temporal_localization('soccer goal', frames)
        >>> print(output)
        [False, False, False, True, True, True, False, False, False, False]
        >>> save_video([f for i, f in enumerate(frames) if output[i]], 'output.mp4')
    """

    buffer_bytes = frames_to_bytes(frames)
    files = [("video", buffer_bytes)]
    payload = {
        "prompt": prompt,
        "chunk_length": 2,
        "function_name": "ixc25_temporal_localization",
    }
    data: List[int] = send_inference_request(
        payload, "video-temporal-localization", files=files, v2=True
    )
    # The service returns one 0/1 flag per video chunk; expand each chunk flag
    # so the output aligns 1:1 with `frames`.
    if not data:
        # No chunks back from the service: nothing was localized.
        return [False] * len(frames)
    # max(1, ...) guards against a zero chunk size when the service returns
    # more chunks than there are frames (round() could yield 0).
    chunk_size = max(1, round(len(frames) / len(data)))
    data_bool = [bool(elt) for elt in data for _ in range(chunk_size)]
    # Rounding can leave the expansion short of len(frames); pad with the last
    # flag so every frame gets a value, then trim any overshoot.
    if len(data_bool) < len(frames):
        data_bool += [data_bool[-1]] * (len(frames) - len(data_bool))
    return data_bool[: len(frames)]


def gpt4o_image_vqa(prompt: str, image: np.ndarray) -> str:
"""'gpt4o_image_vqa' is a tool that can answer any questions about arbitrary images
including regular images or images of documents or presentations. It returns text
Expand Down
32 changes: 22 additions & 10 deletions vision_agent/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def frames_to_bytes(
def extract_frames_from_video(
    video_uri: str, fps: float = 1.0
) -> List[Tuple[np.ndarray, float]]:
    """Extract frames from a video along with the timestamp in seconds.

    Parameters:
        video_uri (str): the path to the video file or a video file url
        fps (float): the desired sampling rate in frames per second; values
            above the video's native frame rate are clamped to it, since we
            cannot sample frames that were never recorded.

    Returns:
        List[Tuple[np.ndarray, float]]: RGB frames paired with the timestamp
            in seconds from the start of the video. E.g. 12.125 means 12.125
            seconds from the start of the video. The frames are sorted by the
            timestamp in ascending order.

    Raises:
        ValueError: if the video's frame rate cannot be determined (e.g. the
            file could not be opened or the container reports no FPS).
    """
    cap = cv2.VideoCapture(video_uri)
    orig_fps = cap.get(cv2.CAP_PROP_FPS)
    if orig_fps <= 0:
        # An unreadable video reports 0 FPS; fail with a clear message instead
        # of a ZeroDivisionError below.
        cap.release()
        raise ValueError(f"Could not determine the frame rate of {video_uri}")
    # Clamp: sampling faster than the native rate would make elapsed_time
    # fall permanently behind targ_frame_time and drift.
    if fps > orig_fps:
        fps = orig_fps

    orig_frame_time = 1 / orig_fps
    targ_frame_time = 1 / fps
    frames: List[Tuple[np.ndarray, float]] = []
    i = 0
    elapsed_time = 0.0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Accumulate source time; emit a frame each time a full target
        # sampling interval has elapsed.
        elapsed_time += orig_frame_time
        if elapsed_time >= targ_frame_time:
            # OpenCV decodes BGR; convert to RGB for the rest of the pipeline.
            frames.append((cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), i / orig_fps))
            elapsed_time -= targ_frame_time

        i += 1
    cap.release()
    return frames

0 comments on commit 7db696d

Please sign in to comment.