Skip to content

Commit

Permalink
update tool return format
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 29, 2024
1 parent 0507f6a commit d9bdcf8
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 72 deletions.
15 changes: 14 additions & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
from .tools import (
CLIP,
TOOLS,
BboxArea,
BboxIoU,
Counter,
Crop,
ExtractFrames,
GroundingDINO,
GroundingSAM,
SegArea,
SegIoU,
Tool,
)
144 changes: 77 additions & 67 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class CLIP(Tool):
}

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
"""Invoke the CLIP model.
Parameters:
Expand Down Expand Up @@ -122,7 +122,7 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict
rets = []
for elt in resp_json["data"]:
rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
return cast(List[Dict], rets)
return cast(Dict, rets[0])


class GroundingDINO(Tool):
Expand Down Expand Up @@ -168,7 +168,7 @@ class GroundingDINO(Tool):
}

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict]:
def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict:
"""Invoke the Grounding DINO model.
Parameters:
Expand Down Expand Up @@ -204,7 +204,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict
if "scores" in elt:
elt["scores"] = [round(score, 2) for score in elt["scores"]]
elt["size"] = (image_size[1], image_size[0])
return cast(List[Dict], resp_data)
return cast(Dict, resp_data)


class GroundingSAM(Tool):
Expand Down Expand Up @@ -259,7 +259,7 @@ class GroundingSAM(Tool):
}

# TODO: Add support for input multiple images, which aligns with the output type.
def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
"""Invoke the Grounding SAM model.
Parameters:
Expand Down Expand Up @@ -294,23 +294,22 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict
ret_pred["labels"].append(pred["label_name"])
ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
ret_pred["masks"].append(mask)
return [ret_pred]
return ret_pred


class AgentGroundingSAM(GroundingSAM):
r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files
returns the file name. This makes it easier for agents to use.
"""

def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
rets = super().__call__(prompt, image)
for ret in rets:
mask_files = []
for mask in ret["masks"]:
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
Image.fromarray(mask * 255).save(tmp)
mask_files.append(tmp.name)
ret["masks"] = mask_files
mask_files = []
for mask in rets["masks"]:
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
Image.fromarray(mask * 255).save(tmp)
mask_files.append(tmp.name)
rets["masks"] = mask_files
return rets


Expand Down Expand Up @@ -363,7 +362,7 @@ class Crop(Tool):
],
}

def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict:
pil_image = Image.open(image)
width, height = pil_image.size
bbox = [
Expand All @@ -373,10 +372,10 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
int(bbox[3] * height),
]
cropped_image = pil_image.crop(bbox) # type: ignore
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
cropped_image.save(tmp.name)

return tmp.name
return {"image": tmp.name}


class BboxArea(Tool):
Expand Down Expand Up @@ -432,15 +431,23 @@ def __call__(self, masks: Union[str, Path]) -> float:

class BboxIoU(Tool):
name = "bbox_iou_"
description = "'bbox_iou_' returns the intersection over union of two bounding boxes."
description = (
"'bbox_iou_' returns the intersection over union of two bounding boxes."
)
usage = {
"required_parameters": [{"name": "bbox1", "type": "List[int]"}, {"name": "bbox2", "type": "List[int]"}],
"required_parameters": [
{"name": "bbox1", "type": "List[int]"},
{"name": "bbox2", "type": "List[int]"},
],
"examples": [
{
"scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]",
"parameters": {"bbox1": [0.2, 0.21, 0.34, 0.42], "bbox2": [0.3, 0.31, 0.44, 0.52]},
"parameters": {
"bbox1": [0.2, 0.21, 0.34, 0.42],
"bbox2": [0.3, 0.31, 0.44, 0.52],
},
}
]
],
}

def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
Expand All @@ -459,13 +466,16 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:

class SegIoU(Tool):
name = "seg_iou_"
description = "'seg_iou_' returns the intersection over union of two segmentation masks."
description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
usage = {
"required_parameters": [{"name": "mask1", "type": "str"}, {"name": "mask2", "type": "str"}],
"required_parameters": [
{"name": "mask1", "type": "str"},
{"name": "mask2", "type": "str"},
],
"examples": [
{
"scenario": "If you want to calculate the intersection over union of the segmentation masks for mask1.png and mask2.png",
"parameters": {"mask1": "mask1.png", "mask2": "mask2.png"},
"scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg",
"parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
}
],
}
Expand All @@ -481,6 +491,47 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
return round(iou, 2)


class ExtractFrames(Tool):
    r"""Extract frames from a video wherever motion is detected.

    Each extracted frame is written to a temporary PNG file on disk so that
    downstream tools (which consume image file paths) can use it directly.
    """

    name = "extract_frames_"
    # Typo fix: "teh" -> "the". This string is consumed by the agent as tool
    # documentation, so the typo degrades the prompt quality.
    description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path."
    usage = {
        "required_parameters": [{"name": "video_uri", "type": "str"}],
        "examples": [
            {
                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
            },
            {
                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
                "parameters": {"video_uri": "tests/data/test.mp4"},
            },
        ],
    }

    def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
        """Extract frames from a video.

        Parameters:
            video_uri: the path to the video file, or a URL that points to the
                video data

        Returns:
            a list of tuples containing the path to the extracted frame image
            and its timestamp in seconds, e.g.
            [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp
            is measured from the start of the video (12.125 means 12.125
            seconds in). Tuples are sorted by timestamp in ascending order.
        """
        frames = extract_frames_from_video(video_uri)
        result = []
        _LOGGER.info(
            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
        )
        for frame, ts in frames:
            # delete=False keeps the temp file on disk so the returned path
            # remains valid for downstream consumers. PIL infers PNG output
            # from the file object's ".png" name.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                Image.fromarray(frame).save(tmp)
                result.append((tmp.name, ts))
        return result


class Add(Tool):
r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""

Expand Down Expand Up @@ -557,54 +608,14 @@ def __call__(self, input: List[int]) -> float:
return round(input[0] / input[1], 2)


class ExtractFrames(Tool):
    r"""Extract frames from a video.

    Each extracted frame is written to a temporary JPEG file on disk so that
    downstream tools (which consume image file paths) can use it directly.
    """

    name = "extract_frames_"
    description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
    usage = {
        "required_parameters": [{"name": "video_uri", "type": "str"}],
        "examples": [
            {
                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
            },
            {
                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
                "parameters": {"video_uri": "tests/data/test.mp4"},
            },
        ],
    }

    # Annotation changed from builtin generics (list[tuple[...]]) to
    # typing.List/Tuple: consistent with the rest of this module and avoids a
    # runtime TypeError on Python < 3.9, where builtin generics are not
    # subscriptable in evaluated annotations.
    def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
        """Extract frames from a video.

        Parameters:
            video_uri: the path to the video file, or a URL that points to the
                video data

        Returns:
            a list of tuples containing the path to the extracted frame image
            and its timestamp in seconds, e.g.
            [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp
            is measured from the start of the video (12.125 means 12.125
            seconds in). Tuples are sorted by timestamp in ascending order.
        """
        frames = extract_frames_from_video(video_uri)
        result = []
        _LOGGER.info(
            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
        )
        for frame, ts in frames:
            # delete=False keeps the temp file on disk so the returned path
            # remains valid for downstream consumers. PIL infers JPEG output
            # from the file object's ".jpg" name.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
                Image.fromarray(frame).save(tmp)
                result.append((tmp.name, ts))
        return result


TOOLS = {
i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
for i, c in enumerate(
[
CLIP,
GroundingDINO,
AgentGroundingSAM,
ExtractFrames,
Counter,
Crop,
BboxArea,
Expand All @@ -615,7 +626,6 @@ def __call__(self, video_uri: str) -> list[tuple[str, float]]:
Subtract,
Multiply,
Divide,
ExtractFrames,
]
)
if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
Expand Down
12 changes: 8 additions & 4 deletions vision_agent/tools/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ def extract_frames_from_video(
Parameters:
video_uri: the path to the video file or a video file url
fps: the frame rate per second to extract the frames
motion_detection_threshold: The threshold to detect motion between changes/frames.
A value between 0-1, which represents the percentage change required for the frames to be considered in motion.
For example, a lower value means more frames will be extracted.
motion_detection_threshold: The threshold to detect motion between
changes/frames. A value between 0-1, which represents the percentage change
required for the frames to be considered in motion. For example, a lower
value means more frames will be extracted.
Returns:
a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
a list of tuples containing the extracted frame and the timestamp in seconds.
E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
the video. The frames are sorted by the timestamp in ascending order.
"""
with VideoFileClip(video_uri) as video:
video_duration: float = video.duration
Expand Down

0 comments on commit d9bdcf8

Please sign in to comment.