Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding bbox stats function, adding optional param to frame extraction #77

Merged
merged 2 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _handle_extract_frames(
# any following processing
for video_file_output in tool_result["call_results"]:
# When the video tool is run with wrong parameters, exit the loop
if len(video_file_output) < 2:
if not isinstance(video_file_output, tuple) or len(video_file_output) < 2:
break
for frame, _ in video_file_output:
image = frame
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
CLIP,
OCR,
TOOLS,
BboxArea,
BboxStats,
BboxIoU,
ObjectDistance,
BoxDistance,
Expand Down
94 changes: 58 additions & 36 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ class GroundingDINO(Tool):
"""

name = "grounding_dino_"
description = "'grounding_dino_' is a tool that can detect and count objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
description = "'grounding_dino_' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions. It returns a list and count of bounding boxes, label names and associated probability scores."
usage = {
"required_parameters": [
{"name": "prompt", "type": "str"},
{"name": "image", "type": "str"},
],
"optional_parameters": [
{"name": "box_threshold", "type": "float"},
{"name": "iou_threshold", "type": "float"},
{"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
{"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
],
"examples": [
{
Expand All @@ -209,7 +209,7 @@ class GroundingDINO(Tool):
"prompt": "red shirt. green shirt",
"image": "shirts.jpg",
"box_threshold": 0.20,
"iou_threshold": 0.75,
"iou_threshold": 0.20,
},
},
],
Expand All @@ -221,7 +221,7 @@ def __call__(
prompt: str,
image: Union[str, Path, ImageType],
box_threshold: float = 0.20,
iou_threshold: float = 0.75,
iou_threshold: float = 0.20,
) -> Dict:
"""Invoke the Grounding DINO model.

Expand Down Expand Up @@ -249,7 +249,7 @@ def __call__(
data["scores"] = [round(score, 2) for score in data["scores"]]
if "labels" in data:
data["labels"] = list(data["labels"])
data["size"] = (image_size[1], image_size[0])
data["image_size"] = image_size
return data


Expand Down Expand Up @@ -277,15 +277,15 @@ class GroundingSAM(Tool):
"""

name = "grounding_sam_"
description = "'grounding_sam_' is a tool that can detect and segment objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
description = "'grounding_sam_' is a tool that can detect and segment multiple objects given a text prompt such as category names or referring expressions. It returns a list of bounding boxes, label names and masks file names and associated probability scores."
usage = {
"required_parameters": [
{"name": "prompt", "type": "str"},
{"name": "image", "type": "str"},
],
"optional_parameters": [
{"name": "box_threshold", "type": "float"},
{"name": "iou_threshold", "type": "float"},
{"name": "box_threshold", "type": "float", "min": 0.1, "max": 0.5},
{"name": "iou_threshold", "type": "float", "min": 0.01, "max": 0.99},
],
"examples": [
{
Expand All @@ -312,7 +312,7 @@ class GroundingSAM(Tool):
"prompt": "red shirt, green shirt",
"image": "shirts.jpg",
"box_threshold": 0.20,
"iou_threshold": 0.75,
"iou_threshold": 0.20,
},
},
],
Expand All @@ -324,7 +324,7 @@ def __call__(
prompt: str,
image: Union[str, ImageType],
box_threshold: float = 0.2,
iou_threshold: float = 0.75,
iou_threshold: float = 0.2,
) -> Dict:
"""Invoke the Grounding SAM model.

Expand Down Expand Up @@ -353,6 +353,7 @@ def __call__(
rle_decode(mask_rle=mask, shape=data["mask_shape"])
for mask in data["masks"]
]
data["image_size"] = image_size
data.pop("mask_shape", None)
return data

Expand Down Expand Up @@ -435,6 +436,8 @@ def __call__(
for mask in data["masks"]
]
data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
mask_shape = data.pop("mask_shape", None)
data["image_size"] = (mask_shape[0], mask_shape[1]) if mask_shape else None
return data


Expand Down Expand Up @@ -790,33 +793,49 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
return {"image": tmp.name}


class BboxArea(Tool):
r"""BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places."""
class BboxStats(Tool):
r"""BboxStats returns the height, width and area of the bounding box in pixels to 2 decimal places."""

name = "bbox_area_"
description = "'bbox_area_' returns the area of the given bounding box in pixels normalized to 2 decimal places."
name = "bbox_stats_"
description = "'bbox_stats_' returns the height, width and area of the given bounding box in pixels to 2 decimal places."
usage = {
"required_parameters": [{"name": "bboxes", "type": "List[int]"}],
"required_parameters": [
{"name": "bboxes", "type": "List[int]"},
{"name": "image_size", "type": "Tuple[int]"},
],
"examples": [
{
"scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
"parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]},
}
"scenario": "Calculate the width and height of the bounding box [0.2, 0.21, 0.34, 0.42]",
"parameters": {
"bboxes": [[0.2, 0.21, 0.34, 0.42]],
"image_size": (500, 1200),
},
},
{
"scenario": "Calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]",
"parameters": {
"bboxes": [[0.2, 0.21, 0.34, 0.42]],
"image_size": (640, 480),
},
},
],
}

def __call__(self, bboxes: List[Dict]) -> List[Dict]:
def __call__(
self, bboxes: List[List[int]], image_size: Tuple[int, int]
) -> List[Dict]:
areas = []
for elt in bboxes:
height, width = elt["size"]
for label, bbox in zip(elt["labels"], elt["bboxes"]):
x1, y1, x2, y2 = bbox
areas.append(
{
"area": round((x2 - x1) * (y2 - y1) * width * height, 2),
"label": label,
}
)
height, width = image_size
for bbox in bboxes:
x1, y1, x2, y2 = bbox
areas.append(
{
"width": round((x2 - x1) * width, 2),
"height": round((y2 - y1) * height, 2),
"area": round((x2 - x1) * (y2 - y1) * width * height, 2),
}
)

return areas


Expand Down Expand Up @@ -1055,22 +1074,25 @@ class ExtractFrames(Tool):
r"""Extract frames from a video."""

name = "extract_frames_"
description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
description = "'extract_frames_' extracts frames from a video every 2 seconds, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
usage = {
"required_parameters": [{"name": "video_uri", "type": "str"}],
"optional_parameters": [{"name": "frames_every", "type": "float"}],
"examples": [
{
"scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
"parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
},
{
"scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
"parameters": {"video_uri": "tests/data/test.mp4"},
"scenario": "Can you extract the images from this video file at every 2 seconds ? Video path: tests/data/test.mp4",
"parameters": {"video_uri": "tests/data/test.mp4", "frames_every": 2},
},
],
}

def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
def __call__(
self, video_uri: str, frames_every: float = 2
) -> List[Tuple[str, float]]:
"""Extract frames from a video.


Expand All @@ -1080,7 +1102,7 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
Returns:
a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
"""
frames = extract_frames_from_video(video_uri)
frames = extract_frames_from_video(video_uri, fps=round(1 / frames_every, 2))
result = []
_LOGGER.info(
f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
Expand Down Expand Up @@ -1183,7 +1205,7 @@ def __call__(self, equation: str) -> float:
AgentDINOv,
ExtractFrames,
Crop,
BboxArea,
BboxStats,
SegArea,
ObjectDistance,
BboxContains,
Expand Down
Loading