From 98482e9a45b4a02ee4370021333a74826f47f026 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 27 Mar 2024 13:27:34 -0700 Subject: [PATCH 1/5] added iou tools --- tests/test_tools.py | 26 ++++++++++++++++++ vision_agent/tools/tools.py | 55 ++++++++++++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 tests/test_tools.py diff --git a/tests/test_tools.py b/tests/test_tools.py new file mode 100644 index 00000000..00f85072 --- /dev/null +++ b/tests/test_tools.py @@ -0,0 +1,26 @@ +import os +import tempfile + +import numpy as np +from PIL import Image + +from vision_agent.tools.tools import BboxIoU, SegIoU + + +def test_bbox_iou(): + bbox1 = [0, 0, 0.75, 0.75] + bbox2 = [0.25, 0.25, 1, 1] + assert BboxIoU()(bbox1, bbox2) == 0.29 + + +def test_seg_iou(): + mask1 = np.zeros((10, 10), dtype=np.uint8) + mask1[2:4, 2:4] = 255 + mask2 = np.zeros((10, 10), dtype=np.uint8) + mask2[3:5, 3:5] = 255 + with tempfile.TemporaryDirectory() as tmpdir: + mask1_path = os.path.join(tmpdir, "mask1.png") + mask2_path = os.path.join(tmpdir, "mask2.png") + Image.fromarray(mask1).save(mask1_path) + Image.fromarray(mask2).save(mask2_path) + assert SegIoU()(mask1_path, mask2_path) == 0.14 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 6e55f210..b543fa9d 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -388,7 +388,7 @@ class BboxArea(Tool): "required_parameters": [{"name": "bbox", "type": "List[int]"}], "examples": [ { - "scenario": "If you want to calculate the area of the bounding box [0, 0, 100, 100]", + "scenario": "If you want to calculate the area of the bounding box [0.2, 0.21, 0.34, 0.42]", "parameters": {"bboxes": [0.2, 0.21, 0.34, 0.42]}, } ], @@ -430,6 +430,57 @@ def __call__(self, masks: Union[str, Path]) -> float: return cast(float, round(np.sum(np_mask) / 255, 2)) +class BboxIoU(Tool): + name = "bbox_iou_" + description = "'bbox_iou_' returns the intersection over union of two bounding boxes." + usage = { + "required_parameters": [{"name": "bbox1", "type": "List[int]"}, {"name": "bbox2", "type": "List[int]"}], + "examples": [ + { + "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]", + "parameters": {"bbox1": [0.2, 0.21, 0.34, 0.42], "bbox2": [0.3, 0.31, 0.44, 0.52]}, + } + ] + } + + def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + xA = max(x1, x3) + yA = max(y1, y3) + xB = min(x2, x4) + yB = min(y2, y4) + inter_area = max(0, xB - xA) * max(0, yB - yA) + boxa_area = (x2 - x1) * (y2 - y1) + boxb_area = (x4 - x3) * (y4 - y3) + iou = inter_area / float(boxa_area + boxb_area - inter_area) + return round(iou, 2) + + +class SegIoU(Tool): + name = "seg_iou_" + description = "'seg_iou_' returns the intersection over union of two segmentation masks." + usage = { + "required_parameters": [{"name": "mask1", "type": "str"}, {"name": "mask2", "type": "str"}], + "examples": [ + { + "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask1.png and mask2.png", + "parameters": {"mask1": "mask1.png", "mask2": "mask2.png"}, + } + ], + } + + def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float: + pil_mask1 = Image.open(str(mask1)) + pil_mask2 = Image.open(str(mask2)) + np_mask1 = np.clip(np.array(pil_mask1), 0, 1) + np_mask2 = np.clip(np.array(pil_mask2), 0, 1) + intersection = np.logical_and(np_mask1, np_mask2) + union = np.logical_or(np_mask1, np_mask2) + iou = np.sum(intersection) / np.sum(union) + return round(iou, 2) + + class Add(Tool): r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places.""" @@ -558,6 +609,8 @@ def __call__(self, video_uri: str) -> list[tuple[str, float]]: Crop, BboxArea, SegArea, + BboxIoU, + SegIoU, Add, Subtract, Multiply, From 0507f6af111a8a0647cdc8c9a152be2eb668f2f6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 29 Mar 2024 10:44:14 -0700 Subject: [PATCH 2/5] add image visualization for reflection --- tests/{ => tools}/test_tools.py | 0 vision_agent/agent/vision_agent.py | 92 ++++++++++++++++++++------ vision_agent/image_utils.py | 100 +++++++++++++++++++++++++++-- 3 files changed, 168 insertions(+), 24 deletions(-) rename tests/{ => tools}/test_tools.py (100%) diff --git a/tests/test_tools.py b/tests/tools/test_tools.py similarity index 100% rename from tests/test_tools.py rename to tests/tools/test_tools.py diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 5d34fe9e..caba0533 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -1,11 +1,14 @@ import json import logging import sys +import tempfile +from os import walk from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union from tabulate import tabulate +from vision_agent.image_utils import overlay_bboxes, overlay_masks from vision_agent.llm import LLM, OpenAILLM from vision_agent.lmm import LMM, OpenAILMM from vision_agent.tools import TOOLS @@ -248,12 +251,12 @@ def retrieval( tools: Dict[int, Any], previous_log: str, reflections: str, -) -> Tuple[List[Dict], str]: +) -> Tuple[Dict, str]: tool_id = choose_tool( model, question, {k: v["description"] for k, v in tools.items()}, reflections ) if tool_id is None: - return [{}], "" + return {}, "" _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})") tool_instructions = tools[tool_id] @@ -265,14 +268,12 @@ def retrieval( ) _LOGGER.info(f"\tParameters: {parameters} for {tool_name}") if parameters is None: - return [{}], "" - tool_results = [ - {"task": question, "tool_name": tool_name, "parameters": parameters} - ] + return {}, "" + tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters} _LOGGER.info( - f"""Going to run the following {len(tool_results)} tool(s) in sequence: -{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}""" + f"""Going to run the following tool(s) in sequence: +{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}""" ) def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: @@ -286,12 +287,10 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: call_results.append(function_call(tools[tool_id]["class"], parameters)) return call_results - call_results = [] - for i, result in enumerate(tool_results): - call_results.extend(parse_tool_results(result)) - tool_results[i]["call_results"] = call_results + call_results = parse_tool_results(tool_results) + tool_results["call_results"] = call_results - call_results_str = "\n\n".join([str(e) for e in call_results if e is not None]) + call_results_str = str(call_results) _LOGGER.info(f"\tCall Results: {call_results_str}") return tool_results, call_results_str @@ -335,7 +334,11 @@ def self_reflect( tool_results=str(tool_result), final_answer=final_answer, ) - if issubclass(type(reflect_model), LMM): + if ( + issubclass(type(reflect_model), LMM) + and image is not None + and Path(image).suffix in [".jpg", ".jpeg", ".png"] + ): return reflect_model(prompt, image=image) # type: ignore return reflect_model(prompt) @@ -345,6 +348,56 @@ def parse_reflect(reflect: str) -> bool: return "finish" in reflect.lower() and len(reflect) < 100 +def visualize_result(all_tool_results: List[Dict]) -> List[str]: + image_to_data = {} + for tool_result in all_tool_results: + if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]: + continue + + parameters = tool_result["parameters"] + # parameters can either be a dictionary or list, parameters can also be malformed + # becaus the LLM builds them + if isinstance(parameters, dict): + if "image" not in parameters: + continue + parameters = [parameters] + elif isinstance(tool_result["parameters"], list): + if ( + len(tool_result["parameters"]) < 1 + and "image" not in tool_result["parameters"][0] + ): + continue + + for param, call_result in zip(parameters, tool_result["call_results"]): + + # calls can fail, so we need to check if the call was successful + if not isinstance(call_result, dict): + continue + if not "bboxes" in call_result: + continue + + # if the call was successful, then we can add the image data + image = param["image"] + if image not in image_to_data: + image_to_data[image] = {"bboxes": [], "masks": [], "labels": []} + + image_to_data[image]["bboxes"].extend(call_result["bboxes"]) + image_to_data[image]["labels"].extend(call_result["labels"]) + if "masks" in call_result: + image_to_data[image]["masks"].extend(call_result["masks"]) + + visualized_images = [] + for image in image_to_data: + image_path = Path(image) + image_data = image_to_data[image] + image = overlay_masks(image_path, image_data) + image = overlay_bboxes(image, image_data) + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + image.save(f.name) + visualized_images.append(f.name) + return visualized_images + + class VisionAgent(Agent): r"""Vision Agent is an agent framework that utilizes tools as well as self reflection to accomplish tasks, in particular vision tasks. Vision Agent is based @@ -389,7 +442,8 @@ def __call__( """Invoke the vision agent. Parameters: - input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}]. + input: a prompt that describe the task or a conversation in the format of + [{"role": "user", "content": "describe your task here..."}]. image: the input image referenced in the prompt parameter. Returns: @@ -436,9 +490,8 @@ def chat_with_workflow( self.answer_model, task_str, call_results, previous_log, reflections ) - for tool_result in tool_results: - tool_result["answer"] = answer - all_tool_results.extend(tool_results) + tool_results["answer"] = answer + all_tool_results.append(tool_results) _LOGGER.info(f"\tAnswer: {answer}") answers.append({"task": task_str, "answer": answer}) @@ -448,13 +501,14 @@ def chat_with_workflow( self.answer_model, question, answers, reflections ) + visualized_images = visualize_result(all_tool_results) reflection = self_reflect( self.reflect_model, question, self.tools, all_tool_results, final_answer, - image, + visualized_images[0] if len(visualized_images) > 0 else image, ) _LOGGER.info(f"\tReflection: {reflection}") if parse_reflect(reflection): diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 05a129ce..849f912f 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -3,15 +3,38 @@ import base64 from io import BytesIO from pathlib import Path -from typing import Tuple, Union +from typing import Dict, Tuple, Union import numpy as np -from PIL import Image +from PIL import Image, ImageDraw, ImageFont from PIL.Image import Image as ImageType +COLORS = [ + (158, 218, 229), + (219, 219, 141), + (23, 190, 207), + (188, 189, 34), + (199, 199, 199), + (247, 182, 210), + (127, 127, 127), + (227, 119, 194), + (196, 156, 148), + (197, 176, 213), + (140, 86, 75), + (148, 103, 189), + (255, 152, 150), + (152, 223, 138), + (214, 39, 40), + (44, 160, 44), + (255, 187, 120), + (174, 199, 232), + (255, 127, 14), + (31, 119, 180), +] + def b64_to_pil(b64_str: str) -> ImageType: - """Convert a base64 string to a PIL Image. + r"""Convert a base64 string to a PIL Image. Parameters: b64_str: the base64 encoded image @@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType: def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]: - """Get the size of an image. + r"""Get the size of an image. Parameters: data: the input image @@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: - """Convert an image to a base64 string. + r"""Convert an image to a base64 string. Parameters: data: the input image @@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: else: arr_bytes = data.tobytes() return base64.b64encode(arr_bytes).decode("utf-8") + + +def overlay_bboxes( + image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict +) -> ImageType: + r"""Plots bounding boxes on to an image. + + Parameters: + image: the input image + bboxes: the bounding boxes to overlay + + Returns: + The image with the bounding boxes overlayed + """ + if isinstance(image, (str, Path)): + image = Image.open(image) + elif isinstance(image, np.ndarray): + image = Image.fromarray(image) + + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])} + + draw = ImageDraw.Draw(image) + font = ImageFont.load_default() + width, height = image.size + if "bboxes" not in bboxes: + return image.convert("RGB") + + for label, box in zip(bboxes["labels"], bboxes["bboxes"]): + box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height] + draw.rectangle(box, outline=color[label], width=3) + label = f"{label}" + text_box = draw.textbbox((box[0], box[1]), text=label, font=font) + draw.rectangle(text_box, fill=color[label]) + draw.text((text_box[0], text_box[1]), label, fill="black", font=font) + return image.convert("RGB") + + +def overlay_masks( + image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5 +) -> ImageType: + r"""Plots masks on to an image. + + Parameters: + image: the input image + masks: the masks to overlay + alpha: the transparency of the overlay + + Returns: + The image with the masks overlayed + """ + if isinstance(image, (str, Path)): + image = Image.open(image) + elif isinstance(image, np.ndarray): + image = Image.fromarray(image) + + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])} + if "masks" not in masks: + return image.convert("RGB") + + for label, mask in zip(masks["labels"], masks["masks"]): + if isinstance(mask, str): + mask = np.array(Image.open(mask)) + np_mask = np.zeros((image.size[1], image.size[0], 4)) + np_mask[mask > 0, :] = color[label] + (255 * alpha,) + mask_img = Image.fromarray(np_mask.astype(np.uint8)) + image = Image.alpha_composite(image.convert("RGBA"), mask_img) + return image.convert("RGB") From d9bdcf8b7a618fcc98b9cc0302fb07cb821d8ed4 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 29 Mar 2024 10:45:24 -0700 Subject: [PATCH 3/5] update tool return format --- vision_agent/tools/__init__.py | 15 +++- vision_agent/tools/tools.py | 144 ++++++++++++++++++--------------- vision_agent/tools/video.py | 12 ++- 3 files changed, 99 insertions(+), 72 deletions(-) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 7e5636c2..20dc503a 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,2 +1,15 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool +from .tools import ( + CLIP, + TOOLS, + BboxArea, + BboxIoU, + Counter, + Crop, + ExtractFrames, + GroundingDINO, + GroundingSAM, + SegArea, + SegIoU, + Tool, +) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index b543fa9d..55bbd477 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -92,7 +92,7 @@ class CLIP(Tool): } # TODO: Add support for input multiple images, which aligns with the output type. - def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict: """Invoke the CLIP model. Parameters: @@ -122,7 +122,7 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict rets = [] for elt in resp_json["data"]: rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]}) - return cast(List[Dict], rets) + return cast(Dict, rets[0]) class GroundingDINO(Tool): @@ -168,7 +168,7 @@ class GroundingDINO(Tool): } # TODO: Add support for input multiple images, which aligns with the output type. - def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict]: + def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> Dict: """Invoke the Grounding DINO model. Parameters: @@ -204,7 +204,7 @@ def __call__(self, prompt: str, image: Union[str, Path, ImageType]) -> List[Dict if "scores" in elt: elt["scores"] = [round(score, 2) for score in elt["scores"]] elt["size"] = (image_size[1], image_size[0]) - return cast(List[Dict], resp_data) + return cast(Dict, resp_data) class GroundingSAM(Tool): @@ -259,7 +259,7 @@ class GroundingSAM(Tool): } # TODO: Add support for input multiple images, which aligns with the output type. - def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict: """Invoke the Grounding SAM model. Parameters: @@ -294,7 +294,7 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict ret_pred["labels"].append(pred["label_name"]) ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size)) ret_pred["masks"].append(mask) - return [ret_pred] + return ret_pred class AgentGroundingSAM(GroundingSAM): @@ -302,15 +302,14 @@ class AgentGroundingSAM(GroundingSAM): returns the file name. This makes it easier for agents to use. """ - def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict: rets = super().__call__(prompt, image) - for ret in rets: - mask_files = [] - for mask in ret["masks"]: - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: - Image.fromarray(mask * 255).save(tmp) - mask_files.append(tmp.name) - ret["masks"] = mask_files + mask_files = [] + for mask in rets["masks"]: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + Image.fromarray(mask * 255).save(tmp) + mask_files.append(tmp.name) + rets["masks"] = mask_files return rets @@ -363,7 +362,7 @@ class Crop(Tool): ], } - def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: + def __call__(self, bbox: List[float], image: Union[str, Path]) -> Dict: pil_image = Image.open(image) width, height = pil_image.size bbox = [ @@ -373,10 +372,10 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: int(bbox[3] * height), ] cropped_image = pil_image.crop(bbox) # type: ignore - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: cropped_image.save(tmp.name) - return tmp.name + return {"image": tmp.name} class BboxArea(Tool): @@ -432,15 +431,23 @@ def __call__(self, masks: Union[str, Path]) -> float: class BboxIoU(Tool): name = "bbox_iou_" - description = "'bbox_iou_' returns the intersection over union of two bounding boxes." + description = ( + "'bbox_iou_' returns the intersection over union of two bounding boxes." + ) usage = { - "required_parameters": [{"name": "bbox1", "type": "List[int]"}, {"name": "bbox2", "type": "List[int]"}], + "required_parameters": [ + {"name": "bbox1", "type": "List[int]"}, + {"name": "bbox2", "type": "List[int]"}, + ], "examples": [ { "scenario": "If you want to calculate the intersection over union of the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]", - "parameters": {"bbox1": [0.2, 0.21, 0.34, 0.42], "bbox2": [0.3, 0.31, 0.44, 0.52]}, + "parameters": { + "bbox1": [0.2, 0.21, 0.34, 0.42], + "bbox2": [0.3, 0.31, 0.44, 0.52], + }, } - ] + ], } def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: @@ -459,13 +466,16 @@ def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: class SegIoU(Tool): name = "seg_iou_" - description = "'seg_iou_' returns the intersection over union of two segmentation masks." + description = "'seg_iou_' returns the intersection over union of two segmentation masks given their segmentation mask files." usage = { - "required_parameters": [{"name": "mask1", "type": "str"}, {"name": "mask2", "type": "str"}], + "required_parameters": [ + {"name": "mask1", "type": "str"}, + {"name": "mask2", "type": "str"}, + ], "examples": [ { - "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask1.png and mask2.png", - "parameters": {"mask1": "mask1.png", "mask2": "mask2.png"}, + "scenario": "If you want to calculate the intersection over union of the segmentation masks for mask_file1.jpg and mask_file2.jpg", + "parameters": {"mask1": "mask_file1.png", "mask2": "mask_file2.png"}, } ], } @@ -481,6 +491,47 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float: return round(iou, 2) +class ExtractFrames(Tool): + r"""Extract frames from a video.""" + + name = "extract_frames_" + description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where teh frame was captured. The frame is a local image file path." + usage = { + "required_parameters": [{"name": "video_uri", "type": "str"}], + "examples": [ + { + "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4", + "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"}, + }, + { + "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4", + "parameters": {"video_uri": "tests/data/test.mp4"}, + }, + ], + } + + def __call__(self, video_uri: str) -> List[Tuple[str, float]]: + """Extract frames from a video. + + + Parameters: + video_uri: the path to the video file or a url points to the video data + + Returns: + a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order. + """ + frames = extract_frames_from_video(video_uri) + result = [] + _LOGGER.info( + f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks." + ) + for frame, ts in frames: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + Image.fromarray(frame).save(tmp) + result.append((tmp.name, ts)) + return result + + class Add(Tool): r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places.""" @@ -557,47 +608,6 @@ def __call__(self, input: List[int]) -> float: return round(input[0] / input[1], 2) -class ExtractFrames(Tool): - r"""Extract frames from a video.""" - - name = "extract_frames_" - description = "'extract_frames_' extract image frames from the input video, return a list of tuple (frame, timestamp), where the timestamp is the relative time in seconds of the frame occurred in the video, the frame is a local image file path that stores the frame." - usage = { - "required_parameters": [{"name": "video_uri", "type": "str"}], - "examples": [ - { - "scenario": "Can you extract the frames from this video? Video: www.foobar.com/video?name=test.mp4", - "parameters": {"video_uri": "www.foobar.com/video?name=test.mp4"}, - }, - { - "scenario": "Can you extract the images from this video file? Video path: tests/data/test.mp4", - "parameters": {"video_uri": "tests/data/test.mp4"}, - }, - ], - } - - def __call__(self, video_uri: str) -> list[tuple[str, float]]: - """Extract frames from a video. - - - Parameters: - video_uri: the path to the video file or a url points to the video data - - Returns: - a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(path_to_frame1, 0.0), (path_to_frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order. - """ - frames = extract_frames_from_video(video_uri) - result = [] - _LOGGER.info( - f"Extracted {len(frames)} frames from video {video_uri}. Temporarily saving them as images to disk for downstream tasks." - ) - for frame, ts in frames: - with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: - Image.fromarray(frame).save(tmp) - result.append((tmp.name, ts)) - return result - - TOOLS = { i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c} for i, c in enumerate( @@ -605,6 +615,7 @@ def __call__(self, video_uri: str) -> list[tuple[str, float]]: CLIP, GroundingDINO, AgentGroundingSAM, + ExtractFrames, Counter, Crop, BboxArea, @@ -615,7 +626,6 @@ def __call__(self, video_uri: str) -> list[tuple[str, float]]: Subtract, Multiply, Divide, - ExtractFrames, ] ) if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage")) diff --git a/vision_agent/tools/video.py b/vision_agent/tools/video.py index 606166ed..6068725f 100644 --- a/vision_agent/tools/video.py +++ b/vision_agent/tools/video.py @@ -22,12 +22,16 @@ def extract_frames_from_video( Parameters: video_uri: the path to the video file or a video file url fps: the frame rate per second to extract the frames - motion_detection_threshold: The threshold to detect motion between changes/frames. - A value between 0-1, which represents the percentage change required for the frames to be considered in motion. - For example, a lower value means more frames will be extracted. + motion_detection_threshold: The threshold to detect motion between + changes/frames. A value between 0-1, which represents the percentage change + required for the frames to be considered in motion. For example, a lower + value means more frames will be extracted. Returns: - a list of tuples containing the extracted frame and the timestamp in seconds. E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds from the start of the video. E.g. 12.125 means 12.125 seconds from the start of the video. The frames are sorted by the timestamp in ascending order. + a list of tuples containing the extracted frame and the timestamp in seconds. + E.g. [(frame1, 0.0), (frame2, 0.5), ...]. The timestamp is the time in seconds + from the start of the video. E.g. 12.125 means 12.125 seconds from the start of + the video. The frames are sorted by the timestamp in ascending order. """ with VideoFileClip(video_uri) as video: video_duration: float = video.duration From 1872e13ddbcef1891c1a3090408763aad3e1edbc Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 29 Mar 2024 10:47:57 -0700 Subject: [PATCH 4/5] typing and flake8 issues --- vision_agent/agent/vision_agent.py | 5 ++--- vision_agent/tools/tools.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index caba0533..d02ee0a9 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -2,7 +2,6 @@ import logging import sys import tempfile -from os import walk from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -349,7 +348,7 @@ def parse_reflect(reflect: str) -> bool: def visualize_result(all_tool_results: List[Dict]) -> List[str]: - image_to_data = {} + image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]: continue @@ -373,7 +372,7 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]: # calls can fail, so we need to check if the call was successful if not isinstance(call_result, dict): continue - if not "bboxes" in call_result: + if "bboxes" not in call_result: continue # if the call was successful, then we can add the image data diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 55bbd477..f13c14dd 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -488,7 +488,7 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float: intersection = np.logical_and(np_mask1, np_mask2) union = np.logical_or(np_mask1, np_mask2) iou = np.sum(intersection) / np.sum(union) - return round(iou, 2) + return cast(float, round(iou, 2)) class ExtractFrames(Tool): From 8e68cff15a4a85214b65f9c6b8312c3308c66810 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 29 Mar 2024 11:20:48 -0700 Subject: [PATCH 5/5] added visualized images to all_tool_results --- vision_agent/agent/vision_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index d02ee0a9..e9e6d66d 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -501,6 +501,7 @@ def chat_with_workflow( ) visualized_images = visualize_result(all_tool_results) + all_tool_results.append({"visualized_images": visualized_images}) reflection = self_reflect( self.reflect_model, question,