Skip to content

Commit

Permalink
adding bbox stats function, adding optional param to frame extraction (
Browse files Browse the repository at this point in the history
…#77)

* adding bbox stats function, adding optional param to frame extraction

* fix linting
  • Loading branch information
shankar-vision-eng authored May 9, 2024
1 parent 39706a0 commit da09f6d
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 38 deletions.
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _handle_extract_frames(
# any following processing
for video_file_output in tool_result["call_results"]:
# When the video tool is run with wrong parameters, exit the loop
if len(video_file_output) < 2:
if not isinstance(video_file_output, tuple) or len(video_file_output) < 2:
break
for frame, _ in video_file_output:
image = frame
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
CLIP,
OCR,
TOOLS,
BboxArea,
BboxStats,
BboxIoU,
BoxDistance,
Crop,
Expand Down
94 changes: 58 additions & 36 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ class GroundingDINO(Tool):
"""

name = "grounding_dino_"
description = "'grounding_dino_' is a tool that can detect and count objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
description = "'grounding_dino_' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
usage = {
"required_parameters": [
{"name": "prompt", "type": "str"},
{"name": "image", "type": "str"},
],
"optional_parameters": [
{"name": "box_threshold", "type": "float"},
{"name": "iou_threshold", "type": "float"},
{"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
{"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
],
"examples": [
{
Expand All @@ -209,7 +209,7 @@ class GroundingDINO(Tool):
"prompt": "red shirt. green shirt",
"image": "shirts.jpg",
"box_threshold": 0.20,
"iou_threshold": 0.75,
"iou_threshold": 0.20,
},
},
],
Expand All @@ -221,7 +221,7 @@ def __call__(
prompt: str,
image: Union[str, Path, ImageType],
box_threshold: float = 0.20,
iou_threshold: float = 0.75,
iou_threshold: float = 0.20,
) -> Dict:
"""Invoke the Grounding DINO model.
Expand Down Expand Up @@ -249,7 +249,7 @@ def __call__(
data["scores"] = [round(score, 2) for score in data["scores"]]
if "labels" in data:
data["labels"] = list(data["labels"])
data["size"] = (image_size[1], image_size[0])
data["image_size"] = image_size
return data


Expand Down Expand Up @@ -277,15 +277,15 @@ class GroundingSAM(Tool):
"""

name = "grounding_sam_"
description = "'grounding_sam_' is a tool that can detect and segment objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
description = "'grounding_sam_' is a tool that can detect and segment multiple objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
usage = {
"required_parameters": [
{"name": "prompt", "type": "str"},
{"name": "image", "type": "str"},
],
"optional_parameters": [
{"name": "box_threshold", "type": "float"},
{"name": "iou_threshold", "type": "float"},
{"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
{"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
],
"examples": [
{
Expand All @@ -312,7 +312,7 @@ class GroundingSAM(Tool):
"prompt": "red shirt, green shirt",
"image": "shirts.jpg",
"box_threshold": 0.20,
"iou_threshold": 0.75,
"iou_threshold": 0.20,
},
},
],
Expand All @@ -324,7 +324,7 @@ def __call__(
prompt: str,
image: Union[str, ImageType],
box_threshold: float = 0.2,
iou_threshold: float = 0.75,
iou_threshold: float = 0.2,
) -> Dict:
"""Invoke the Grounding SAM model.
Expand Down Expand Up @@ -353,6 +353,7 @@ def __call__(
rle_decode(mask_rle=mask, shape=data["mask_shape"])
for mask in data["masks"]
]
data["image_size"] = image_size
data.pop("mask_shape", None)
return data

Expand Down Expand Up @@ -434,6 +435,8 @@ def __call__(
for mask in data["masks"]
]
data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
mask_shape = data.pop("mask_shape", None)
data["image_size"] = (mask_shape[0], mask_shape[1]) if mask_shape else None
return data


Expand Down Expand Up @@ -789,33 +792,49 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
return {"image": tmp.name}


class BboxArea(Tool):
r"""BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places."""
class BboxStats(Tool):
r"""BboxStats returns the height, width and area of the bounding box in pixels to 2 decimal places."""

name = "bbox_area_"
description = "'bbox_area_' returns the area of the given bounding box in pixels normalized to 2 decimal places."
name = "bbox_stats_"
description = "'bbox_stats_' returns the height, width and area of the given bounding box in pixels to 2 decimal places."
usage = {
"required_parameters": [{"name": "bboxes", "type": "List[int]"}],
"required_parameters": [
{"name": "bboxes", "type": "List[int]"},
{"name": "image_size", "type": "Tuple[int]"},
],
"examples": [
{
"scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
"parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
}
"scenario": "Calculate the width and height of the bounding box [0.2, 0.21, 0.34, 0.42]",
"parameters": {
"bboxes": [[0.2, 0.21, 0.34, 0.42]],
"image_size": (500, 1200),
},
},
{
"scenario": "Calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
"parameters": {
"bboxes": [[0.2, 0.21, 0.34, 0.42]],
"image_size": (640, 480),
},
},
],
}

def __call__(self, bboxes: List[Dict]) -> List[Dict]:
def __call__(
self, bboxes: List[List[int]], image_size: Tuple[int, int]
) -> List[Dict]:
areas = []
for elt in bboxes:
height, width = elt["size"]
for label, bbox in zip(elt["labels"], elt["bboxes"]):
x1, y1, x2, y2 = bbox
areas.append(
{
"area": round((x2 - x1) * (y2 - y1) * width * height, 2),
"label": label,
}
)
height, width = image_size
for bbox in bboxes:
x1, y1, x2, y2 = bbox
areas.append(
{
"width": round((x2 - x1) * width, 2),
"height": round((y2 - y1) * height, 2),
"area": round((x2 - x1) * (y2 - y1) * width * height, 2),
}
)

return areas


Expand Down Expand Up @@ -1054,22 +1073,25 @@ class ExtractFrames(Tool):
r"""Extract frames from a video."""

name = "extract_frames_"
description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
description = "'extract_frames_' extracts frames from a video every 2 seconds, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
usage = {
"required_parameters": [{"name": "video_uri", "type": "str"}],
"optional_parameters": [{"name": "frames_every", "type": "float"}],
"examples": [
{
"scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
"parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
},
{
"scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
"parameters": {"video_uri": "tests/data/test.mp4"},
"scenario": "Can you extract the images from this video file at every 2 seconds ? Video path: tests/data/test.mp4",
"parameters": {"video_uri": "tests/data/test.mp4", "frames_every": 2},
},
],
}

def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
def __call__(
self, video_uri: str, frames_every: float = 2
) -> List[Tuple[str, float]]:
"""Extract frames from a video.
Expand All @@ -1079,7 +1101,7 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
Returns:
a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
"""
frames = extract_frames_from_video(video_uri)
frames = extract_frames_from_video(video_uri, fps=round(1 / frames_every, 2))
result = []
_LOGGER.info(
f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
Expand Down Expand Up @@ -1182,7 +1204,7 @@ def __call__(self, equation: str) -> float:
AgentDINOv,
ExtractFrames,
Crop,
BboxArea,
BboxStats,
SegArea,
ObjectDistance,
BboxContains,
Expand Down

0 comments on commit da09f6d

Please sign in to comment.