diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 7e5636c2..20dc503a 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -1,2 +1,15 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
+from .tools import (
+    CLIP,
+    TOOLS,
+    BboxArea,
+    BboxIoU,
+    Counter,
+    Crop,
+    ExtractFrames,
+    GroundingDINO,
+    GroundingSAM,
+    SegArea,
+    SegIoU,
+    Tool,
+)
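# A minimal usage sketch of the expanded export list above. It is hypothetical:
# BboxIoU's __init__ is not shown in this diff, so a no-argument constructor is
# assumed.
from vision_agent.tools import BboxIoU

bbox_iou = BboxIoU()
# Boxes are normalized [x1, y1, x2, y2], matching the usage examples in tools.py.
iou = bbox_iou([0.2, 0.21, 0.34, 0.42], [0.3, 0.31, 0.44, 0.52])
print(iou)  # roughly 0.08 for these boxes under standard IoU, rounded to 2 decimals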
""" - def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict: rets = super().__call__(prompt, image) - for ret in rets: - mask_files = [] - for mask in ret["masks"]: - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: - Image.fromarray(mask * 255).save(tmp) - mask_files.append(tmp.name) - ret["masks"] = mask_files + mask_files = [] + for mask in rets["masks"]: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + Image.fromarray(mask * 255).save(tmp) + mask_files.append(tmp.name) + rets["masks"] = mask_files return rets @@ -363,7 +362,7 @@ class Crop(Tool): ], } - def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: + def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict: pil_image = Image.open(image) width, height = pil_image.size bbox = [ @@ -373,10 +372,10 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: int(bbox[3] * height), ] cropped_image = pil_image.crop(bbox) # type: ignore - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: cropped_image.save(tmp.name) - return tmp.name + return {"image": tmp.name} class BboxArea(Tool): @@ -432,15 +431,23 @@ def __call__(self, masks: Union[str, Path]) -> float: class BboxIoU(Tool): name = "bbox_iou_" - description = "'bbox_iou_' returns the intersection over union of two bounding boxes." + description = ( + "'bbox_iou_' returns the intersection over union of two bounding boxes." + ) usage = { - "required_parameters": [{"name": "bbox1", "type": "List[int]"}, {"name": "bbox2", "type": "List[int]"}], + "required_parameters": [ + {"name": "bbox1", "type": "List[int]"}, + {"name": "bbox2", "type": "List[int]"}, + ], "examples": [ { "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]", - "parameters": {"bbox1": [0.2, 0.21, 0.34, 0.42], "bbox2": [0.3, 0.31, 0.44, 0.52]}, + "parameters": { + "bbox1": [0.2, 0.21, 0.34, 0.42], + "bbox2": [0.3, 0.31, 0.44, 0.52], + }, } - ] + ], } def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: @@ -459,13 +466,16 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: class SegIoU(Tool): name = "seg_iou_" - description = "'seg_iou_' returns the intersection over union of two segmentation masks." + description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files." 
@@ -459,13 +466,16 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float:
 
 class SegIoU(Tool):
     name = "seg_iou_"
-    description = "'seg_iou_' returns the intersection over union of two segmentation masks."
+    description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files."
     usage = {
-        "required_parameters": [{"name": "mask1", "type": "str"}, {"name": "mask2", "type": "str"}],
+        "required_parameters": [
+            {"name": "mask1", "type": "str"},
+            {"name": "mask2", "type": "str"},
+        ],
         "examples": [
             {
-                "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask1.png and mask2.png",
-                "parameters": {"mask1": "mask1.png", "mask2": "mask2.png"},
+                "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.png and mask_file2.png",
+                "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"},
             }
         ],
     }
@@ -481,6 +491,47 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float:
         return round(iou, 2)
 
 
+class ExtractFrames(Tool):
+    r"""Extract frames from a video."""
+
+    name = "extract_frames_"
+    description = "'extract_frames_' extracts frames from a video where motion is detected and returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds at which the frame was captured. The frame is a local image file path."
+    usage = {
+        "required_parameters": [{"name": "video_uri", "type": "str"}],
+        "examples": [
+            {
+                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
+                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
+            },
+            {
+                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
+                "parameters": {"video_uri": "tests/data/test.mp4"},
+            },
+        ],
+    }
+
+    def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
+        """Extract frames from a video.
+
+
+        Parameters:
+            video_uri: the path to the video file or a URL pointing to the video data
+
+        Returns:
+            a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+        """
+        frames = extract_frames_from_video(video_uri)
+        result = []
+        _LOGGER.info(
+            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
+        )
+        for frame, ts in frames:
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+                Image.fromarray(frame).save(tmp)
+            result.append((tmp.name, ts))
+        return result
+
+
 class Add(Tool):
     r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""
 
@@ -557,47 +608,6 @@ def __call__(self, input: List[int]) -> float:
         return round(input[0] / input[1], 2)
 
 
-class ExtractFrames(Tool):
-    r"""Extract frames from a video."""
-
-    name = "extract_frames_"
-    description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame."
-    usage = {
-        "required_parameters": [{"name": "video_uri", "type": "str"}],
-        "examples": [
-            {
-                "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4",
-                "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"},
-            },
-            {
-                "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4",
-                "parameters": {"video_uri": "tests/data/test.mp4"},
-            },
-        ],
-    }
-
-    def __call__(self, video_uri: str) -> list[tuple[str, float]]:
-        """Extract frames from a video.
-
-
-        Parameters:
-            video_uri: the path to the video file or a url points to the video data
-
-        Returns:
-            a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
-        """
-        frames = extract_frames_from_video(video_uri)
-        result = []
-        _LOGGER.info(
-            f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks."
-        )
-        for frame, ts in frames:
-            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
-                Image.fromarray(frame).save(tmp)
-            result.append((tmp.name, ts))
-        return result
-
-
 TOOLS = {
     i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
     for i, c in enumerate(
@@ -605,6 +615,7 @@ def __call__(self, video_uri: str) -> list[tuple[str, float]]:
         CLIP,
         GroundingDINO,
         AgentGroundingSAM,
+        ExtractFrames,
         Counter,
         Crop,
         BboxArea,
@@ -615,7 +626,6 @@ def __call__(self, video_uri: str) -> list[tuple[str, float]]:
         Subtract,
         Multiply,
         Divide,
-        ExtractFrames,
     ]
     )
     if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
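# A hypothetical end-to-end sketch of the relocated ExtractFrames tool, based
# only on the signature and docstring above; tests/data/test.mp4 is the sample
# path from the usage examples.
from vision_agent.tools import ExtractFrames

extract_frames = ExtractFrames()
frames = extract_frames("tests/data/test.mp4")

# Each entry is (path_to_png_frame, timestamp_in_seconds), sorted by timestamp.
for path, ts in frames:
    print(f"{ts:.3f}s -> {path}")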
diff --git a/vision_agent/tools/video.py b/vision_agent/tools/video.py
index 606166ed..6068725f 100644
--- a/vision_agent/tools/video.py
+++ b/vision_agent/tools/video.py
@@ -22,12 +22,16 @@ def extract_frames_from_video(
     Parameters:
         video_uri: the path to the video file or a video file url
         fps: the frame rate per second to extract the frames
-        motion_detection_threshold: The threshold to detect motion between changes/frames.
-            A value between 0-1, which represents the percentage change required for the frames to be considered in motion.
-            For example, a lower value means more frames will be extracted.
+        motion_detection_threshold: The threshold to detect motion between
+            changes/frames. A value between 0-1, which represents the percentage change
+            required for the frames to be considered in motion. For example, a lower
+            value means more frames will be extracted.
 
     Returns:
-        a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order.
+        a list of tuples containing the extracted frame and the timestamp in seconds.
+        E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds
+        from the start of the video. E.g. 12.125 means 12.125 seconds from the start of
+        the video. The frames are sorted by the timestamp in ascending order.
     """
     with VideoFileClip(video_uri) as video:
         video_duration: float = video.duration
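# The reflowed docstring above describes how motion_detection_threshold behaves
# but does not show it. Below is a simplified, hypothetical re-implementation of
# that contract (frame-to-frame change as a 0-1 fraction), not the actual body
# of extract_frames_from_video; the default fps and threshold values here are
# illustrative only.
from typing import List, Tuple

import numpy as np
from moviepy.editor import VideoFileClip


def sketch_extract_frames(
    video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.06
) -> List[Tuple[np.ndarray, float]]:
    result: List[Tuple[np.ndarray, float]] = []
    prev = None
    with VideoFileClip(video_uri) as video:
        for i, frame in enumerate(video.iter_frames(fps=fps, dtype="uint8")):
            ts = i / fps  # seconds from the start of the video
            if prev is not None:
                # Mean absolute pixel change, normalized to a 0-1 fraction:
                # lower thresholds keep more frames, as the docstring notes.
                change = np.abs(frame.astype(np.int16) - prev).mean() / 255.0
                if change < motion_detection_threshold:
                    prev = frame
                    continue  # not enough motion; skip this frame
            prev = frame
            result.append((frame, ts))
    return result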