From 8415cb306f40e0d1bb12a06e59df07aa4bb306fc Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Wed, 17 Apr 2024 14:51:50 -0700 Subject: [PATCH] feat: allow disable motion detection in frame extraction function (#55) * Tweak frame extraction function * remove default motion detection, extract at 0.5 fps * lmm now take multiple images * removed counter * tweaked prompt * updated vision agent to reflect on multiple images * fix test case * added box distance * adjusted prompts --------- Co-authored-by: Yazhou Cao Co-authored-by: Dillon Laird --- tests/test_lmm.py | 2 +- tests/tools/test_tools.py | 25 ++++- vision_agent/agent/easytool_prompts.py | 10 +- vision_agent/agent/reflexion.py | 20 ++-- vision_agent/agent/vision_agent.py | 22 +++-- vision_agent/agent/vision_agent_prompts.py | 9 +- vision_agent/data/data.py | 4 +- vision_agent/lmm/lmm.py | 102 ++++++++++++--------- vision_agent/tools/__init__.py | 4 +- vision_agent/tools/tools.py | 63 +++++++------ vision_agent/tools/video.py | 26 +++--- 11 files changed, 176 insertions(+), 111 deletions(-) diff --git a/tests/test_lmm.py b/tests/test_lmm.py index 16b5691f..876678b9 100644 --- a/tests/test_lmm.py +++ b/tests/test_lmm.py @@ -27,7 +27,7 @@ def create_temp_image(image_format="jpeg"): def test_generate_with_mock(openai_lmm_mock): # noqa: F811 temp_image = create_temp_image() lmm = OpenAILMM() - response = lmm.generate("test prompt", image=temp_image) + response = lmm.generate("test prompt", images=[temp_image]) assert response == "mocked response" assert ( "image_url" diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index c5d7c6cc..12c21347 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -4,7 +4,7 @@ import numpy as np from PIL import Image -from vision_agent.tools.tools import BboxIoU, SegArea, SegIoU +from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU def test_bbox_iou(): @@ -42,3 +42,26 @@ def test_seg_area_2(): mask_path = os.path.join(tmpdir, "mask.png") Image.fromarray(mask).save(mask_path) assert SegArea()(mask_path) == 4.0 + + +def test_box_distance(): + box_dist = BoxDistance() + # horizontal dist + box1 = [0, 0, 2, 2] + box2 = [4, 1, 6, 3] + assert box_dist(box1, box2) == 2.0 + + # vertical dist + box1 = [0, 0, 2, 2] + box2 = [1, 4, 3, 6] + assert box_dist(box1, box2) == 2.0 + + # vertical and horizontal + box1 = [0, 0, 2, 2] + box2 = [3, 3, 5, 5] + assert box_dist(box1, box2) == 1.41 + + # overlap + box1 = [0, 0, 2, 2] + box2 = [1, 1, 3, 3] + assert box_dist(box1, box2) == 0.0 diff --git a/vision_agent/agent/easytool_prompts.py b/vision_agent/agent/easytool_prompts.py index 6e20dc17..73045ba8 100644 --- a/vision_agent/agent/easytool_prompts.py +++ b/vision_agent/agent/easytool_prompts.py @@ -56,6 +56,7 @@ These are logs of previous questions and answers: {previous_log} + This is the current user's question: {question} This is the API tool documentation: {tool_usage} Output: """ @@ -67,15 +68,22 @@ 2. We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible. 3. If the API tool does not provide useful information in the response, please answer with your knowledge. 4. The question may have dependencies on answers of other questions, so we will provide logs of previous questions and answers. 
+ These are logs of previous questions and answers: {previous_log} + This is the user's question: {question} + This is the response output by the API tool: {call_results} + We will not show the API response to the user, thus you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible. Output: """ ANSWER_SUMMARIZE = """We break down a complex user's problems into simple subtasks and provide answers to each simple subtask. You need to organize these answers to each subtask and form a self-consistent final answer to the user's question. This is the user's question: {question} -These are subtasks and their answers: {answers} + +These are subtasks and their answers: +{answers} + Final answer: """ diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index ac7d77b6..61dded6d 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -238,12 +238,20 @@ def prompt_agent( self._build_agent_prompt(question, reflections, scratchpad) ) ) - return format_step( - self.action_agent( - self._build_agent_prompt(question, reflections, scratchpad), - image=image, + elif isinstance(self.action_agent, LMM): + return format_step( + self.action_agent( + self._build_agent_prompt(question, reflections, scratchpad), + images=[image] if image is not None else None, + ) + ) + elif isinstance(self.action_agent, Agent): + return format_step( + self.action_agent( + self._build_agent_prompt(question, reflections, scratchpad), + image=image, + ) ) - ) def prompt_reflection( self, @@ -261,7 +269,7 @@ def prompt_reflection( return format_step( self.self_reflect_model( self._build_reflect_prompt(question, context, scratchpad), - image=image, + images=[image] if image is not None else None, ) ) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 4c193aae..a3f09b82 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -3,7 +3,7 @@ import sys import tempfile from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from PIL import Image from tabulate import tabulate @@ -264,7 +264,7 @@ def self_reflect( tools: Dict[int, Any], tool_result: List[Dict], final_answer: str, - image: Optional[Union[str, Path]] = None, + images: Optional[Sequence[Union[str, Path]]] = None, ) -> str: prompt = VISION_AGENT_REFLECTION.format( question=question, @@ -275,10 +275,10 @@ def self_reflect( ) if ( issubclass(type(reflect_model), LMM) - and image is not None - and Path(image).suffix in [".jpg", ".jpeg", ".png"] + and images is not None + and all([Path(image).suffix in [".jpg", ".jpeg", ".png"] for image in images]) ): - return reflect_model(prompt, image=image) # type: ignore + return reflect_model(prompt, images=images) # type: ignore return reflect_model(prompt) @@ -357,7 +357,7 @@ def _handle_viz_tools( return image_to_data -def visualize_result(all_tool_results: List[Dict]) -> List[str]: +def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]: image_to_data: Dict[str, Dict] = {} for tool_result in all_tool_results: # only handle bbox/mask tools or frame extraction @@ -407,7 +407,7 @@ def __init__( task_model: Optional[Union[LLM, LMM]] = None, answer_model: Optional[Union[LLM, LMM]] = None, reflect_model: Optional[Union[LLM, LMM]] = None, - max_retries: int = 2, + max_retries: int = 3, 
verbose: bool = False, report_progress_callback: Optional[Callable[[str], None]] = None, ): @@ -519,13 +519,19 @@ def chat_with_workflow( visualized_output = visualize_result(all_tool_results) all_tool_results.append({"visualized_output": visualized_output}) + if len(visualized_output) > 0: + reflection_images = visualized_output + elif image is not None: + reflection_images = [image] + else: + reflection_images = None reflection = self_reflect( self.reflect_model, question, self.tools, all_tool_results, final_answer, - visualized_output[0] if len(visualized_output) > 0 else image, + reflection_images, ) self.log_progress(f"Reflection: {reflection}") parsed_reflection = parse_reflect(reflection) diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index cd7878c0..4cc65845 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -1,11 +1,11 @@ -VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. +VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question, the tool usage for each of the tools used and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. Please note that: 1. You must ONLY output parsible JSON format. If the agents output was correct set "Finish" to true, else set "Finish" to false. An example output looks like: {{"Finish": true, "Reflection": "The agent's answer was correct."}} -2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or, using your own judgement, utilized incorrectly. -3. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like: - {{"Finish": false, "Reflection": "I can see from teh visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters: +2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or if the tools were used incorrectly or the wrong tools were used. +3. If the agent's answer was incorrect, you must diagnose the reason for failure and devise a new concise and concrete plan that aims to mitigate the same failure with the tools available. An example output looks like: + {{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. 
The agent should use the following tools with the following parameters: Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives. Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}} 4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true. @@ -140,4 +140,5 @@ This is a reflection from a previous failed attempt: {reflections} + Final answer: """ diff --git a/vision_agent/data/data.py b/vision_agent/data/data.py index 6b51488b..54125772 100644 --- a/vision_agent/data/data.py +++ b/vision_agent/data/data.py @@ -63,9 +63,9 @@ def add_column( self.df[name] = self.df["image_paths"].progress_apply( # type: ignore lambda x: ( - func(self.lmm.generate(prompt, image=x)) + func(self.lmm.generate(prompt, images=[x])) if func - else self.lmm.generate(prompt, image=x) + else self.lmm.generate(prompt, images=[x]) ) ) return self diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 738ac004..615804ed 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -30,12 +30,16 @@ def encode_image(image: Union[str, Path]) -> str: class LMM(ABC): @abstractmethod - def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: + def generate( + self, prompt: str, images: Optional[List[Union[str, Path]]] = None + ) -> str: pass @abstractmethod def chat( - self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + self, + chat: List[Dict[str, str]], + images: Optional[List[Union[str, Path]]] = None, ) -> str: pass @@ -43,7 +47,7 @@ def chat( def __call__( self, input: Union[str, List[Dict[str, str]]], - image: Optional[Union[str, Path]] = None, + images: Optional[List[Union[str, Path]]] = None, ) -> str: pass @@ -57,27 +61,29 @@ def __init__(self, model_name: str): def __call__( self, input: Union[str, List[Dict[str, str]]], - image: Optional[Union[str, Path]] = None, + images: Optional[List[Union[str, Path]]] = None, ) -> str: if isinstance(input, str): - return self.generate(input, image) - return self.chat(input, image) + return self.generate(input, images) + return self.chat(input, images) def chat( - self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + self, + chat: List[Dict[str, str]], + images: Optional[List[Union[str, Path]]] = None, ) -> str: raise NotImplementedError("Chat not supported for LLaVA") def generate( self, prompt: str, - image: Optional[Union[str, Path]] = None, + images: Optional[List[Union[str, Path]]] = None, temperature: float = 0.1, max_new_tokens: int = 1500, ) -> str: data = {"prompt": prompt} - if image: - data["image"] = encode_image(image) + if images and len(images) > 0: + data["image"] = encode_image(images[0]) data["temperature"] = temperature # type: ignore data["max_new_tokens"] = max_new_tokens # type: ignore res = requests.post( @@ -121,14 +127,16 @@ def __init__( def __call__( self, input: Union[str, List[Dict[str, str]]], - image: Optional[Union[str, Path]] = None, + images: Optional[List[Union[str, Path]]] = None, ) -> str: if isinstance(input, str): - return self.generate(input, image) - return self.chat(input, image) + return self.generate(input, images) + return self.chat(input, images) def chat( - self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + self, + chat: List[Dict[str, str]], + images: Optional[List[Union[str, Path]]] = None, ) -> str: fixed_chat = [] for c in 
chat: @@ -136,25 +144,26 @@ def chat( fixed_c["content"] = [{"type": "text", "text": c["content"]}] # type: ignore fixed_chat.append(fixed_c) - if image: - extension = Path(image).suffix - if extension.lower() == ".jpeg" or extension.lower() == ".jpg": - extension = "jpg" - elif extension.lower() == ".png": - extension = "png" - else: - raise ValueError(f"Unsupported image extension: {extension}") - - encoded_image = encode_image(image) - fixed_chat[0]["content"].append( # type: ignore - { - "type": "image_url", - "image_url": { - "url": f"data:image/{extension};base64,{encoded_image}", - "detail": "low", + if images and len(images) > 0: + for image in images: + extension = Path(image).suffix + if extension.lower() == ".jpeg" or extension.lower() == ".jpg": + extension = "jpg" + elif extension.lower() == ".png": + extension = "png" + else: + raise ValueError(f"Unsupported image extension: {extension}") + + encoded_image = encode_image(image) + fixed_chat[0]["content"].append( # type: ignore + { + "type": "image_url", + "image_url": { + "url": f"data:image/{extension};base64,{encoded_image}", + "detail": "low", + }, }, - }, - ) + ) response = self.client.chat.completions.create( model=self.model_name, messages=fixed_chat, **self.kwargs # type: ignore @@ -162,7 +171,11 @@ def chat( return cast(str, response.choices[0].message.content) - def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: + def generate( + self, + prompt: str, + images: Optional[List[Union[str, Path]]] = None, + ) -> str: message: List[Dict[str, Any]] = [ { "role": "user", @@ -171,18 +184,19 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str ], } ] - if image: - extension = Path(image).suffix - encoded_image = encode_image(image) - message[0]["content"].append( - { - "type": "image_url", - "image_url": { - "url": f"data:image/{extension};base64,{encoded_image}", - "detail": "low", + if images and len(images) > 0: + for image in images: + extension = Path(image).suffix + encoded_image = encode_image(image) + message[0]["content"].append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/{extension};base64,{encoded_image}", + "detail": "low", + }, }, - }, - ) + ) response = self.client.chat.completions.create( model=self.model_name, messages=message, **self.kwargs # type: ignore diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index d2dea3e2..63931c9f 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,10 +1,10 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import ( +from .tools import ( # Counter, CLIP, TOOLS, BboxArea, BboxIoU, - Counter, + BoxDistance, Crop, ExtractFrames, GroundingDINO, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 792f1ba1..3a5c8a4f 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,7 +1,6 @@ import logging import tempfile from abc import ABC -from collections import Counter as CounterClass from pathlib import Path from typing import Any, Dict, List, Tuple, Union, cast @@ -396,33 +395,6 @@ def __call__( return rets -class Counter(Tool): - r"""Counter detects and counts the number of objects in an image given an input such as a category name or referring expression.""" - - name = "counter_" - description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression. 
It returns a dictionary containing the labels and their counts." - usage = { - "required_parameters": [ - {"name": "prompt", "type": "str"}, - {"name": "image", "type": "str"}, - ], - "examples": [ - { - "scenario": "Can you count the number of cars in this image? Image name image.jpg", - "parameters": {"prompt": "car", "image": "image.jpg"}, - }, - { - "scenario": "Can you count the number of people? Image name: people.png", - "parameters": {"prompt": "person", "image": "people.png"}, - }, - ], - } - - def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: - resp = GroundingDINO()(prompt, image) - return dict(CounterClass(resp["labels"])) - - class Crop(Tool): r"""Crop crops an image given a bounding box and returns a file name of the cropped image.""" @@ -573,11 +545,42 @@ def __call__(self, mask1: Union[str, Path], mask2: Union[str, Path]) -> float: return cast(float, round(iou, 2)) +class BoxDistance(Tool): + name = "box_distance_" + description = ( + "'box_distance_' returns the minimum distance between two bounding boxes." + ) + usage = { + "required_parameters": [ + {"name": "bbox1", "type": "List[int]"}, + {"name": "bbox2", "type": "List[int]"}, + ], + "examples": [ + { + "scenario": "If you want to calculate the distance between the bounding boxes [0.2, 0.21, 0.34, 0.42] and [0.3, 0.31, 0.44, 0.52]", + "parameters": { + "bbox1": [0.2, 0.21, 0.34, 0.42], + "bbox2": [0.3, 0.31, 0.44, 0.52], + }, + } + ], + } + + def __call__(self, bbox1: List[int], bbox2: List[int]) -> float: + x11, y11, x12, y12 = bbox1 + x21, y21, x22, y22 = bbox2 + + horizontal_dist = np.max([0, x21 - x12, x11 - x22]) + vertical_dist = np.max([0, y21 - y12, y11 - y22]) + + return cast(float, round(np.sqrt(horizontal_dist**2 + vertical_dist**2), 2)) + + class ExtractFrames(Tool): r"""Extract frames from a video.""" name = "extract_frames_" - description = "'extract_frames_' extracts frames where there is motion detected in a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where teh frame was captured. The frame is a local image file path." + description = "'extract_frames_' extracts frames from a video, returns a list of tuples (frame, timestamp), where timestamp is the relative time in seconds where the frame was captured. The frame is a local image file path." usage = { "required_parameters": [{"name": "video_uri", "type": "str"}], "examples": [ @@ -650,12 +653,12 @@ def __call__(self, equation: str) -> float: GroundingDINO, AgentGroundingSAM, ExtractFrames, - Counter, Crop, BboxArea, SegArea, BboxIoU, SegIoU, + BoxDistance, Calculator, ] ) diff --git a/vision_agent/tools/video.py b/vision_agent/tools/video.py index 6068725f..4eca66af 100644 --- a/vision_agent/tools/video.py +++ b/vision_agent/tools/video.py @@ -15,7 +15,7 @@ def extract_frames_from_video( - video_uri: str, fps: int = 2, motion_detection_threshold: float = 0.06 + video_uri: str, fps: float = 0.5, motion_detection_threshold: float = 0.0 ) -> List[Tuple[np.ndarray, float]]: """Extract frames from a video @@ -25,7 +25,8 @@ def extract_frames_from_video( motion_detection_threshold: The threshold to detect motion between changes/frames. A value between 0-1, which represents the percentage change required for the frames to be considered in motion. For example, a lower - value means more frames will be extracted. + value means more frames will be extracted. A non-positive value will disable + motion detection and extract all frames. 
Returns: a list of tuples containing the extracted frame and the timestamp in seconds. @@ -119,18 +120,19 @@ def _extract_frames_by_clip( total=processable_frames, desc=f"Extracting frames from clip {start}-{end}" ) for i, frame in enumerate(clip.iter_frames(fps=fps, dtype="uint8")): - curr_processed_frame = _preprocess_frame(frame) total_count += 1 pbar.update(1) - # Skip the frame if it is similar to the previous one - if prev_processed_frame is not None and _similar_frame( - prev_processed_frame, - curr_processed_frame, - threshold=motion_detection_threshold, - ): - skipped_count += 1 - continue - prev_processed_frame = curr_processed_frame + if motion_detection_threshold > 0: + curr_processed_frame = _preprocess_frame(frame) + # Skip the frame if it is similar to the previous one + if prev_processed_frame is not None and _similar_frame( + prev_processed_frame, + curr_processed_frame, + threshold=motion_detection_threshold, + ): + skipped_count += 1 + continue + prev_processed_frame = curr_processed_frame ts = round(clip.reader.pos / source_fps, 3) frames.append((frame, ts))
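
Reviewer note on the video.py change above: a minimal usage sketch of the new frame-extraction defaults (fps now defaults to 0.5 and a non-positive motion_detection_threshold disables motion filtering entirely). The video filename is a placeholder, and this assumes the patched vision_agent package and its moviepy dependency are installed.

from vision_agent.tools.video import extract_frames_from_video

# Post-patch defaults: sample at 0.5 fps with motion detection disabled
# (threshold 0.0), so every sampled frame is kept.
frames = extract_frames_from_video("example_video.mp4")  # placeholder path

# Motion-based filtering can still be enabled by passing a positive threshold,
# roughly matching the previous defaults (fps=2, threshold=0.06).
moving_frames = extract_frames_from_video(
    "example_video.mp4", fps=2, motion_detection_threshold=0.06
)

for frame, timestamp in frames:
    # Each entry is (numpy frame, timestamp in seconds from the start of the clip).
    print(timestamp, frame.shape)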
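
Reviewer note on the lmm.py change: callers now pass images as a list instead of a single image keyword, and every image in the list is attached to the request. Below is a small sketch of the new call style; the prompt and file paths are placeholders, the import path follows vision_agent/lmm/lmm.py, and an OpenAI API key is assumed to be configured.

from vision_agent.lmm.lmm import OpenAILMM

lmm = OpenAILMM()  # default model/key handling as configured by the class

# Pre-patch style: lmm.generate("Describe the scene", image="frame1.jpg")
# Post-patch style: a list of images, all of which are sent with the prompt.
answer = lmm.generate(
    "Compare these two frames and describe what changed",
    images=["frame1.jpg", "frame2.jpg"],  # placeholder image paths
)
print(answer)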
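
Reviewer note on the new BoxDistance tool: a standalone sketch of the distance it computes, i.e. the minimum edge-to-edge gap between two axis-aligned [x1, y1, x2, y2] boxes (0.0 when they overlap). The helper name min_box_distance is illustrative only; the shipped implementation is the BoxDistance class in vision_agent/tools/tools.py, and the assertions mirror the new cases in tests/tools/test_tools.py.

import math

def min_box_distance(box1, box2):
    # Minimum gap between two [x1, y1, x2, y2] boxes; 0.0 if they overlap.
    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2
    horizontal = max(0, x21 - x12, x11 - x22)
    vertical = max(0, y21 - y12, y11 - y22)
    return round(math.hypot(horizontal, vertical), 2)

assert min_box_distance([0, 0, 2, 2], [4, 1, 6, 3]) == 2.0   # horizontal gap only
assert min_box_distance([0, 0, 2, 2], [3, 3, 5, 5]) == 1.41  # diagonal gap, sqrt(2)
assert min_box_distance([0, 0, 2, 2], [1, 1, 3, 3]) == 0.0   # overlapping boxes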