From c438d9be72fbec14d78be502b456865c0041e3d5 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 6 Sep 2024 16:57:39 -0700 Subject: [PATCH] Improve video usage (#229) * add owlv2 video * update doc extract_frames to include urls * fix countgd return decimal places * fixed return types * prompt tests to run faster * testing owlv2_video * updated name to florence2_sam2_video_tracking * lowered threshold * ran isort * fix mypy errors * fix tests' * fix tests --- tests/integ/test_tools.py | 30 +++- tests/integration_dev/test_tools.py | 5 +- .../agent/vision_agent_coder_prompts.py | 48 +++++- vision_agent/tools/__init__.py | 5 +- vision_agent/tools/tools.py | 163 ++++++++++++++---- vision_agent/utils/video.py | 2 +- 6 files changed, 193 insertions(+), 60 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index bca1f6ea..24bd259f 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -10,11 +10,11 @@ detr_segmentation, dpt_hybrid_midas, florence2_image_caption, - florence2_phrase_grounding, florence2_ocr, + florence2_phrase_grounding, florence2_roberta_vqa, florence2_sam2_image, - florence2_sam2_video, + florence2_sam2_video_tracking, generate_pose_image, generate_soft_edge_image, git_vqa_v2, @@ -25,7 +25,8 @@ loca_visual_prompt_counting, loca_zero_shot_counting, ocr, - owl_v2, + owl_v2_image, + owl_v2_video, template_match, vit_image_classification, vit_nsfw_classification, @@ -53,14 +54,27 @@ def test_grounding_dino_tiny(): assert [res["label"] for res in result] == ["coin"] * 24 -def test_owl(): +def test_owl_v2_image(): img = ski.data.coins() - result = owl_v2( + result = owl_v2_image( prompt="coin", image=img, ) - assert len(result) == 25 - assert [res["label"] for res in result] == ["coin"] * 25 + assert 24 <= len(result) <= 26 + assert [res["label"] for res in result] == ["coin"] * len(result) + + +def test_owl_v2_video(): + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = owl_v2_video( + prompt="coin", + frames=frames, + ) + + assert len(result) == 10 + assert 24 <= len([res["label"] for res in result[0]]) <= 26 def test_object_detection(): @@ -108,7 +122,7 @@ def test_florence2_sam2_video(): frames = [ np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) ] - result = florence2_sam2_video( + result = florence2_sam2_video_tracking( prompt="coin", frames=frames, ) diff --git a/tests/integration_dev/test_tools.py b/tests/integration_dev/test_tools.py index 29262245..246c5642 100644 --- a/tests/integration_dev/test_tools.py +++ b/tests/integration_dev/test_tools.py @@ -1,9 +1,6 @@ import skimage as ski -from vision_agent.tools import ( - countgd_counting, - countgd_example_based_counting, -) +from vision_agent.tools import countgd_counting, countgd_example_based_counting def test_countgd_counting() -> None: diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py index b4c8a9bf..df68372c 100644 --- a/vision_agent/agent/vision_agent_coder_prompts.py +++ b/vision_agent/agent/vision_agent_coder_prompts.py @@ -70,30 +70,64 @@ 2. Create a dictionary where the keys are the tool name and the values are the tool outputs. Remove numpy arrays from the printed dictionary. 3. Your test case MUST run only on the given images which are {media} 4. Print this final dictionary. +5. For video input, sample at 1 FPS and use the first 10 frames only to reduce processing time. 
**Example**:
+--- EXAMPLE1 ---
 plan1:
 - Load the image from the provided file path 'image.jpg'.
-- Use the 'owl_v2' tool with the prompt 'person' to detect and count the number of people in the image.
+- Use the 'owl_v2_image' tool with the prompt 'person' to detect and count the number of people in the image.
 plan2:
 - Load the image from the provided file path 'image.jpg'.
-- Use the 'grounding_sam' tool with the prompt 'person' to detect and count the number of people in the image.
+- Use the 'florence2_sam2_image' tool with the prompt 'person' to detect and count the number of people in the image.
 - Count the number of detected objects labeled as 'person'.
 plan3:
 - Load the image from the provided file path 'image.jpg'.
 - Use the 'countgd_counting' tool to count the dominant foreground object, which in this case is people.
 
 ```python
-from vision_agent.tools import load_image, owl_v2, grounding_sam, countgd_counting
+from vision_agent.tools import load_image, owl_v2_image, florence2_sam2_image, countgd_counting
 image = load_image("image.jpg")
-owl_v2_out = owl_v2("person", image)
+owl_v2_out = owl_v2_image("person", image)
 
-gsam_out = grounding_sam("person", image)
-gsam_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in gsam_out]
+f2s2_out = florence2_sam2_image("person", image)
+# strip out the masks from the output because they don't provide useful information when printed
+f2s2_out = [{{k: v for k, v in o.items() if k != "mask"}} for o in f2s2_out]
 
 cgd_out = countgd_counting(image)
 
-final_out = {{"owl_v2": owl_v2_out, "florencev2_object_detection": florencev2_out, "countgd_counting": cgd_out}}
+final_out = {{"owl_v2_image": owl_v2_out, "florence2_sam2_image": f2s2_out, "countgd_counting": cgd_out}}
+print(final_out)
+
+--- EXAMPLE2 ---
+plan1:
+- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
+- Use the 'owl_v2_image' tool with the prompt 'person' to detect where the people are in the video.
+plan2:
+- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
+- Use the 'florence2_phrase_grounding' tool with the prompt 'person' to detect where the people are in the video.
+plan3:
+- Extract frames from 'video.mp4' at 10 FPS using the 'extract_frames' tool.
+- Use the 'countgd_counting' tool with the prompt 'person' to detect where the people are in the video.
+
+
+```python
+from vision_agent.tools import extract_frames, owl_v2_image, florence2_phrase_grounding, countgd_counting
+
+# sample at 1 FPS and use the first 10 frames to reduce processing time
+frames = extract_frames("video.mp4", 1)
+frames = [f[0] for f in frames][:10]
+
+# plan1
+owl_v2_out = [owl_v2_image("person", f) for f in frames]
+
+# plan2
+florence2_out = [florence2_phrase_grounding("person", f) for f in frames]
+
+# plan3
+countgd_out = [countgd_counting(f) for f in frames]
+
+final_out = {{"owl_v2_image": owl_v2_out, "florence2_phrase_grounding": florence2_out, "countgd_counting": countgd_out}}
 print(final_out)
 ```
 """
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 90858569..f7b1e4c0 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -27,7 +27,7 @@
     florence2_phrase_grounding,
     florence2_roberta_vqa,
     florence2_sam2_image,
-    florence2_sam2_video,
+    florence2_sam2_video_tracking,
     generate_pose_image,
     generate_soft_edge_image,
     get_tool_documentation,
@@ -46,7 +46,8 @@
     overlay_counting_results,
     overlay_heat_map,
     overlay_segmentation_masks,
-    owl_v2,
+    owl_v2_image,
+    owl_v2_video,
     save_image,
     save_json,
     save_video,
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 31d53f98..e8e23ba6 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -145,15 +145,15 @@ def grounding_dino(
     return return_data
 
 
-def owl_v2(
+def owl_v2_image(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.10,
 ) -> List[Dict[str, Any]]:
-    """'owl_v2' is a tool that can detect and count multiple objects given a text
-    prompt such as category names or referring expressions. The categories in text
-    prompt are separated by commas. It returns a list of bounding boxes with normalized
-    coordinates, label names and associated probability scores.
+    """'owl_v2_image' is a tool that can detect and count multiple objects given a text
+    prompt such as category names or referring expressions on images. The categories in
+    text prompt are separated by commas. It returns a list of bounding boxes with
+    normalized coordinates, label names and associated probability scores.
 
     Parameters:
         prompt (str): The prompt to ground to the image.
@@ -170,32 +170,103 @@ def owl_v2(
 
     Example
     -------
-        >>> owl_v2("car, dinosaur", image)
+        >>> owl_v2_image("car, dinosaur", image)
         [
             {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
             {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5},
         ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(image)
-    request_data = {
+    buffer_bytes = numpy_to_bytes(image)
+    files = [("image", buffer_bytes)]
+    payload = {
         "prompts": [s.strip() for s in prompt.split(",")],
-        "image": image_b64,
-        "confidence": box_threshold,
-        "function_name": "owl_v2",
+        "model": "owlv2",
+        "function_name": "owl_v2_image",
     }
-    data: Dict[str, Any] = send_inference_request(request_data, "owlv2", v2=True)
-    return_data = []
+    resp_data = send_inference_request(
+        payload, "text-to-object-detection", files=files, v2=True
+    )
+    bboxes = resp_data[0]
+    bboxes_formatted = [
+        ODResponseData(
+            label=bbox["label"],
+            bbox=normalize_bbox(bbox["bounding_box"], image_size),
+            score=round(bbox["score"], 2),
+        )
+        for bbox in bboxes
+    ]
+    filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
+    return [bbox.model_dump() for bbox in filtered_bboxes]
+
+
+def owl_v2_video(
+    prompt: str,
+    frames: List[np.ndarray],
+    box_threshold: float = 0.10,
+) -> List[List[Dict[str, Any]]]:
+    """'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple
+    objects per frame given a text prompt such as a category name or referring
+    expression. The categories in text prompt are separated by commas. It returns a list
+    of lists where each inner list contains the score, label, and bounding box of the
+    detections for that frame.
+
+    Parameters:
+        prompt (str): The prompt to ground to the video.
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+        box_threshold (float, optional): The threshold for the box detection. Defaults
+            to 0.10.
+
+    Returns:
+        List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the
+            score, label, and bounding box of the detected objects with normalized
+            coordinates between 0 and 1 (xmin, ymin, xmax, ymax). xmin and ymin are the
+            coordinates of the top-left and xmax and ymax are the coordinates of the
+            bottom-right of the bounding box.
+
+    Example
+    -------
+        >>> owl_v2_video("car, dinosaur", frames)
+        [
+            [
+                {'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]},
+                {'score': 0.98, 'label': 'car', 'bbox': [0.2, 0.21, 0.45, 0.5]},
+            ],
+            ...
+        ]
+    """
+    if len(frames) == 0:
+        raise ValueError("No frames provided")
+
+    image_size = frames[0].shape[:2]
+    buffer_bytes = frames_to_bytes(frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        "prompts": [s.strip() for s in prompt.split(",")],
+        "model": "owlv2",
+        "function_name": "owl_v2_video",
+    }
+    data: Dict[str, Any] = send_inference_request(
+        payload, "text-to-object-detection", files=files, v2=True
+    )
+    bboxes_formatted = []
     if data is not None:
-        for elt in data:
-            return_data.append(
-                {
-                    "bbox": normalize_bbox(elt["bbox"], image_size),  # type: ignore
-                    "label": elt["label"],  # type: ignore
-                    "score": round(elt["score"], 2),  # type: ignore
-                }
-            )
-    return return_data
+        for frame_data in data:
+            bboxes_formatted_frame = []
+            for elt in frame_data:
+                bboxes_formatted_frame.append(
+                    ODResponseData(
+                        label=elt["label"],  # type: ignore
+                        bbox=normalize_bbox(elt["bounding_box"], image_size),  # type: ignore
+                        score=round(elt["score"], 2),  # type: ignore
+                    )
+                )
+            bboxes_formatted.append(bboxes_formatted_frame)
+
+    filtered_bboxes = [
+        filter_bboxes_by_threshold(elt, box_threshold) for elt in bboxes_formatted
+    ]
+    return [[bbox.model_dump() for bbox in frame] for frame in filtered_bboxes]
 
 
 def grounding_sam(
@@ -317,14 +388,14 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
     return return_data
 
 
-def florence2_sam2_video(
+def florence2_sam2_video_tracking(
     prompt: str, frames: List[np.ndarray]
 ) -> List[List[Dict[str, Any]]]:
-    """'florence2_sam2_video' is a tool that can segment and track multiple entities
-    in a video given a text prompt such as category names or referring expressions. You
-    can optionally separate the categories in the text with commas. It only tracks
-    entities present in the first frame and only returns segmentation masks. It is
-    useful for tracking and counting without duplicating counts.
+    """'florence2_sam2_video_tracking' is a tool that can segment and track multiple
+    entities in a video given a text prompt such as category names or referring
+    expressions. You can optionally separate the categories in the text with commas. It
+    only tracks entities present in the first frame and only returns segmentation
+    masks. It is useful for tracking and counting without duplicating counts.
 
     Parameters:
         prompt (str): The prompt to ground to the video.
@@ -351,14 +422,15 @@ def florence2_sam2_video(
                     [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
                 },
             ],
+        ...
] """ buffer_bytes = frames_to_bytes(frames) files = [("video", buffer_bytes)] payload = { - "prompts": prompt.split(","), - "function_name": "florence2_sam2_video", + "prompts": [s.strip() for s in prompt.split(",")], + "function_name": "florence2_sam2_video_tracking", } data: Dict[str, Any] = send_inference_request( payload, "florence2-sam2", files=files, v2=True @@ -549,7 +621,14 @@ def countgd_counting( payload, "text-to-object-detection", files=files, metadata=metadata ) bboxes_per_frame = resp_data[0] - bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] + bboxes_formatted = [ + ODResponseData( + label=bbox["label"], + bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + score=round(bbox["score"], 2), + ) + for bbox in bboxes_per_frame + ] filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) return [bbox.model_dump() for bbox in filtered_bboxes] @@ -601,7 +680,14 @@ def countgd_example_based_counting( payload, "visual-prompts-to-object-detection", files=files, metadata=metadata ) bboxes_per_frame = resp_data[0] - bboxes_formatted = [ODResponseData(**bbox) for bbox in bboxes_per_frame] + bboxes_formatted = [ + ODResponseData( + label=bbox["label"], + bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + score=round(bbox["score"], 2), + ) + for bbox in bboxes_per_frame + ] filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold) return [bbox.model_dump() for bbox in filtered_bboxes] @@ -1374,12 +1460,12 @@ def closest_box_distance( def extract_frames( video_uri: Union[str, Path], fps: float = 1 ) -> List[Tuple[np.ndarray, float]]: - """'extract_frames' extracts frames from a video which can be a file path or youtube - link, returns a list of tuples (frame, timestamp), where timestamp is the relative - time in seconds where the frame was captured. The frame is a numpy array. + """'extract_frames' extracts frames from a video which can be a file path, url or + youtube link, returns a list of tuples (frame, timestamp), where timestamp is the + relative time in seconds where the frame was captured. The frame is a numpy array. Parameters: - video_uri (Union[str, Path]): The path to the video file or youtube link + video_uri (Union[str, Path]): The path to the video file, url or youtube link fps (float, optional): The frame rate per second to extract the frames. Defaults to 10. @@ -1820,7 +1906,8 @@ def overlay_counting_results( FUNCTION_TOOLS = [ - owl_v2, + owl_v2_image, + owl_v2_video, ocr, clip, vit_image_classification, @@ -1829,7 +1916,7 @@ def overlay_counting_results( florence2_image_caption, florence2_ocr, florence2_sam2_image, - florence2_sam2_video, + florence2_sam2_video_tracking, florence2_phrase_grounding, ixc25_image_vqa, ixc25_video_vqa, diff --git a/vision_agent/utils/video.py b/vision_agent/utils/video.py index 51774279..d306f295 100644 --- a/vision_agent/utils/video.py +++ b/vision_agent/utils/video.py @@ -4,8 +4,8 @@ from functools import lru_cache from typing import List, Optional, Tuple -import cv2 import av # type: ignore +import cv2 import numpy as np from decord import VideoReader # type: ignore
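
A minimal usage sketch of the renamed video tools introduced above, assuming only the signatures defined in this patch (`extract_frames`, `owl_v2_video`, `florence2_sam2_video_tracking`); the video path `people.mp4` and the `person` prompt are illustrative placeholders.

```python
from vision_agent.tools import (
    extract_frames,
    florence2_sam2_video_tracking,
    owl_v2_video,
)

# "people.mp4" is a hypothetical local video used purely for illustration.
# extract_frames returns (frame, timestamp) tuples; sample at 1 FPS and keep
# only the first 10 frames, mirroring the guidance added to the coder prompt.
frames = [frame for frame, _ in extract_frames("people.mp4", fps=1)][:10]

# owl_v2_video batches all frames into one request and returns, per frame, a
# list of {"score", "label", "bbox"} dictionaries with normalized coordinates.
detections_per_frame = owl_v2_video("person", frames, box_threshold=0.10)
for i, detections in enumerate(detections_per_frame):
    print(f"frame {i}: {len(detections)} detections")

# florence2_sam2_video_tracking segments the entities found in the first frame
# and tracks them, returning per-frame lists of {"label", "mask"} dictionaries.
tracks_per_frame = florence2_sam2_video_tracking("person", frames)
print(f"{len(tracks_per_frame[0])} tracked entities over {len(tracks_per_frame)} frames")
```

Unlike the per-frame `owl_v2_image` calls shown in EXAMPLE2 of the coder prompt, `owl_v2_video` sends the whole clip in a single request, which should generally be preferable when per-frame detections are all that is needed.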