From da09f6d7d20548d3e6b50f84bd56d124e916e59c Mon Sep 17 00:00:00 2001 From: Shankar <90070882+shankar-landing-ai@users.noreply.github.com> Date: Thu, 9 May 2024 11:32:08 -0700 Subject: [PATCH] adding bbox stats function, adding optional param to frame extraction (#77) * adding bbox stats function, adding optional param to frame extraction * fix linting --- vision_agent/agent/vision_agent.py | 2 +- vision_agent/tools/__init__.py | 2 +- vision_agent/tools/tools.py | 94 ++++++++++++++++++------------ 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index d80bf9de..0e63a70b 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -308,7 +308,7 @@ def _handle_extract_frames( # any following processing for video_file_output in tool_result["call_results"]: # When the video tool is run with wrong parameters, exit the loop - if len(video_file_output) < 2: + if not isinstance(video_file_output, tuple) or len(video_file_output) < 2: break for frame, _ in video_file_output: image = frame diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 5c713498..75b9830e 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -3,7 +3,7 @@ CLIP, OCR, TOOLS, - BboxArea, + BboxStats, BboxIoU, BoxDistance, Crop, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index f61e17b8..cad7f4ad 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -174,15 +174,15 @@ class GroundingDINO(Tool): """ name = "grounding_dino_" - description = "'grounding_dino_' is a tool that can detect and count objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores." + description = "'grounding_dino_' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores." usage = { "required_parameters": [ {"name": "prompt", "type": "str"}, {"name": "image", "type": "str"}, ], "optional_parameters": [ - {"name": "box_threshold", "type": "float"}, - {"name": "iou_threshold", "type": "float"}, + {"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5}, + {"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99}, ], "examples": [ { @@ -209,7 +209,7 @@ class GroundingDINO(Tool): "prompt": "red shirt. green shirt", "image": "shirts.jpg", "box_threshold": 0.20, - "iou_threshold": 0.75, + "iou_threshold": 0.20, }, }, ], @@ -221,7 +221,7 @@ def __call__( prompt: str, image: Union[str, Path, ImageType], box_threshold: float = 0.20, - iou_threshold: float = 0.75, + iou_threshold: float = 0.20, ) -> Dict: """Invoke the Grounding DINO model. @@ -249,7 +249,7 @@ def __call__( data["scores"] = [round(score, 2) for score in data["scores"]] if "labels" in data: data["labels"] = list(data["labels"]) - data["size"] = (image_size[1], image_size[0]) + data["image_size"] = image_size return data @@ -277,15 +277,15 @@ class GroundingSAM(Tool): """ name = "grounding_sam_" - description = "'grounding_sam_' is a tool that can detect and segment objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores." + description = "'grounding_sam_' is a tool that can detect and segment multiple objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores." usage = { "required_parameters": [ {"name": "prompt", "type": "str"}, {"name": "image", "type": "str"}, ], "optional_parameters": [ - {"name": "box_threshold", "type": "float"}, - {"name": "iou_threshold", "type": "float"}, + {"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5}, + {"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99}, ], "examples": [ { @@ -312,7 +312,7 @@ class GroundingSAM(Tool): "prompt": "red shirt, green shirt", "image": "shirts.jpg", "box_threshold": 0.20, - "iou_threshold": 0.75, + "iou_threshold": 0.20, }, }, ], @@ -324,7 +324,7 @@ def __call__( prompt: str, image: Union[str, ImageType], box_threshold: float = 0.2, - iou_threshold: float = 0.75, + iou_threshold: float = 0.2, ) -> Dict: """Invoke the Grounding SAM model. @@ -353,6 +353,7 @@ def __call__( rle_decode(mask_rle=mask, shape=data["mask_shape"]) for mask in data["masks"] ] + data["image_size"] = image_size data.pop("mask_shape", None) return data @@ -434,6 +435,8 @@ def __call__( for mask in data["masks"] ] data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))] + mask_shape = data.pop("mask_shape", None) + data["image_size"] = (mask_shape[0], mask_shape[1]) if mask_shape else None return data @@ -789,33 +792,49 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict: return {"image": tmp.name} -class BboxArea(Tool): - r"""BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places.""" +class BboxStats(Tool): + r"""BboxStats returns the height, width and area of the bounding box in pixels to 2 decimal places.""" - name = "bbox_area_" - description = "'bbox_area_' returns the area of the given bounding box in pixels normalized to 2 decimal places." + name = "bbox_stats_" + description = "'bbox_stats_' returns the height, width and area of the given bounding box in pixels to 2 decimal places." usage = { - "required_parameters": [{"name": "bboxes", "type": "List[int]"}], + "required_parameters": [ + {"name": "bboxes", "type": "List[int]"}, + {"name": "image_size", "type": "Tuple[int]"}, + ], "examples": [ { - "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]", - "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]}, - } + "scenario": "Calculate the width and height of the bounding box [0.2, 0.21, 0.34, 0.42]", + "parameters": { + "bboxes": [[0.2, 0.21, 0.34, 0.42]], + "image_size": (500, 1200), + }, + }, + { + "scenario": "Calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]", + "parameters": { + "bboxes": [[0.2, 0.21, 0.34, 0.42]], + "image_size": (640, 480), + }, + }, ], } - def __call__(self, bboxes: List[Dict]) -> List[Dict]: + def __call__( + self, bboxes: List[List[int]], image_size: Tuple[int, int] + ) -> List[Dict]: areas = [] - for elt in bboxes: - height, width = elt["size"] - for label, bbox in zip(elt["labels"], elt["bboxes"]): - x1, y1, x2, y2 = bbox - areas.append( - { - "area": round((x2 - x1) * (y2 - y1) * width * height, 2), - "label": label, - } - ) + height, width = image_size + for bbox in bboxes: + x1, y1, x2, y2 = bbox + areas.append( + { + "width": round((x2 - x1) * width, 2), + "height": round((y2 - y1) * height, 2), + "area": round((x2 - x1) * (y2 - y1) * width * height, 2), + } + ) + return areas @@ -1054,22 +1073,25 @@ class ExtractFrames(Tool): r"""Extract frames from a video.""" name = "extract_frames_" - description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path." + description = "'extract_frames_' extracts frames from a video every 2 seconds, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path." usage = { "required_parameters": [{"name": "video_uri", "type": "str"}], + "optional_parameters": [{"name": "frames_every", "type": "float"}], "examples": [ { "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4", "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"}, }, { - "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4", - "parameters": {"video_uri": "tests/data/test.mp4"}, + "scenario": "Can you extract the images from this video file at every 2 seconds ? Video path: tests/data/test.mp4", + "parameters": {"video_uri": "tests/data/test.mp4", "frames_every": 2}, }, ], } - def __call__(self, video_uri: str) -> List[Tuple[str, float]]: + def __call__( + self, video_uri: str, frames_every: float = 2 + ) -> List[Tuple[str, float]]: """Extract frames from a video. @@ -1079,7 +1101,7 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]: Returns: a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order. """ - frames = extract_frames_from_video(video_uri) + frames = extract_frames_from_video(video_uri, fps=round(1 / frames_every, 2)) result = [] _LOGGER.info( f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks." @@ -1182,7 +1204,7 @@ def __call__(self, equation: str) -> float: AgentDINOv, ExtractFrames, Crop, - BboxArea, + BboxStats, SegArea, ObjectDistance, BboxContains,