Skip to content

Commit

Permalink
Add frame extraction tool for video processing (#26)
Browse files Browse the repository at this point in the history
* Add frame extraction tool for video processing

* Update ExtractFrames tool

* minor updates

* Update docs

* Attempt to support py 3.9

* Update pyproject.toml

* Test 3.9 and 3.11 in CI

---------

Co-authored-by: Yazhou Cao <[email protected]>
  • Loading branch information
humpydonkey and AsiaCao authored Mar 27, 2024
1 parent 8a44672 commit c8d9620
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
Test:
strategy:
matrix:
python-version: [3.10.11]
python-version: [3.9, 3.11]
os: [ ubuntu-22.04, windows-2022, macos-12 ]
runs-on: ${{ matrix.os }}
steps:
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ you. For example:
| Crop | Crop crops an image given a bounding box and returns a file name of the cropped image. |
| BboxArea | BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places. |
| SegArea | SegArea returns the area of the segmentation mask in pixels normalized to 2 decimal places. |
| ExtractFrames | ExtractFrames extracts image frames from the input video. |


It also has a basic set of calculate tools such as add, subtract, multiply and divide.
4 changes: 3 additions & 1 deletion docs/api/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@

::: vision_agent.tools.prompts

::: vision_agent.tools.tools
::: vision_agent.tools.tools

::: vision_agent.tools.video
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ you. For example:
| Crop | Crop crops an image given a bounding box and returns a file name of the cropped image. |
| BboxArea | BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places. |
| SegArea | SegArea returns the area of the segmentation mask in pixels normalized to 2 decimal places. |
| ExtractFrames | ExtractFrames extracts image frames from the input video. |


It also has a basic set of calculate tools such as add, subtract, multiply and divide.
126 changes: 124 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ packages = [{include = "vision_agent"}]
"documentation" = "https://github.com/landing-ai/vision-agent"

[tool.poetry.dependencies] # main dependency group
python = ">=3.10,<3.12"
python = ">=3.9,<3.12"

numpy = ">=1.21.0,<2.0.0"
pillow = "10.*"
Expand All @@ -28,6 +28,8 @@ torch = "2.1.*" # 2.2 causes sentence-transformers to seg fault
sentence-transformers = "2.*"
openai = "1.*"
typing_extensions = "4.*"
moviepy = "1.*"
opencv-python-headless = "4.*"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down Expand Up @@ -84,4 +86,5 @@ module = [
"faiss.*",
"openai.*",
"sentence_transformers.*",
"moviepy.*",
]
Binary file added tests/data/video/test.mp4
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/tools/test_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from vision_agent.tools.video import extract_frames_from_video


def test_extract_frames_from_video():
# TODO: consider generating a video on the fly instead
video_path = "tests/data/video/test.mp4"
res = extract_frames_from_video(video_path)
assert len(res) == 1
43 changes: 43 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL.Image import Image as ImageType

from vision_agent.image_utils import convert_to_b64, get_image_size
from vision_agent.tools.video import extract_frames_from_video

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -505,6 +506,47 @@ def __call__(self, input: List[int]) -> float:
return round(input[0] / input[1], 2)


class ExtractFrames(Tool):
r"""Extract frames from a video."""

name = "extract_frames_"
description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
usage = {
"required_parameters": [{"name": "video_uri", "type": "str"}],
"examples": [
{
"scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
"parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
},
{
"scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
"parameters": {"video_uri": "tests/data/test.mp4"},
},
],
}

def __call__(self, video_uri: str) -> list[tuple[str, float]]:
"""Extract frames from a video.
Parameters:
video_uri: the path to the video file or a url points to the video data
Returns:
a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
"""
frames = extract_frames_from_video(video_uri)
result = []
_LOGGER.info(
f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
)
for frame, ts in frames:
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
Image.fromarray(frame).save(tmp)
result.append((tmp.name, ts))
return result


TOOLS = {
i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
for i, c in enumerate(
Expand All @@ -520,6 +562,7 @@ def __call__(self, input: List[int]) -> float:
Subtract,
Multiply,
Divide,
ExtractFrames,
]
)
if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
Expand Down
Loading

0 comments on commit c8d9620

Please sign in to comment.