diff --git a/README.md b/README.md
index bff0a12a..741b8ff2 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,8 @@ the individual steps and tools to get the answer:
     }
 ]],
 "answer": "The jar is located at [0.58, 0.2, 0.72, 0.45].",
-}]
+},
+{"visualize_output": "final_output.png"}]
 ```
 
 ### Tools
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 95ae444b..2f1d58b4 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+from PIL import Image
 from tabulate import tabulate
 
 from vision_agent.image_utils import overlay_bboxes, overlay_masks
@@ -288,9 +289,8 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
                 continue
             parameters = [parameters]
         elif isinstance(tool_result["parameters"], list):
-            if (
-                len(tool_result["parameters"]) < 1
-                and "image" not in tool_result["parameters"][0]
+            if len(tool_result["parameters"]) < 1 or (
+                "image" not in tool_result["parameters"][0]
             ):
                 continue
@@ -304,10 +304,16 @@ def visualize_result(all_tool_results: List[Dict]) -> List[str]:
             # if the call was successful, then we can add the image data
             image = param["image"]
             if image not in image_to_data:
-                image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}
+                image_to_data[image] = {
+                    "bboxes": [],
+                    "masks": [],
+                    "labels": [],
+                    "scores": [],
+                }
 
             image_to_data[image]["bboxes"].extend(call_result["bboxes"])
             image_to_data[image]["labels"].extend(call_result["labels"])
+            image_to_data[image]["scores"].extend(call_result["scores"])
 
             if "masks" in call_result:
                 image_to_data[image]["masks"].extend(call_result["masks"])
@@ -345,7 +351,7 @@ def __init__(
         task_model: Optional[Union[LLM, LMM]] = None,
         answer_model: Optional[Union[LLM, LMM]] = None,
         reflect_model: Optional[Union[LLM, LMM]] = None,
-        max_retries: int = 2,
+        max_retries: int = 3,
         verbose: bool = False,
         report_progress_callback: Optional[Callable[[str], None]] = None,
     ):
@@ -380,6 +386,7 @@ def __call__(
         self,
         input: Union[List[Dict[str, str]], str],
         image: Optional[Union[str, Path]] = None,
+        visualize_output: Optional[bool] = False,
     ) -> str:
         """Invoke the vision agent.
@@ -393,7 +400,7 @@ def __call__(
         """
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        return self.chat(input, image=image)
+        return self.chat(input, image=image, visualize_output=visualize_output)
 
     def log_progress(self, description: str) -> None:
         _LOGGER.info(description)
@@ -401,7 +408,10 @@ def log_progress(self, description: str) -> None:
             self.report_progress_callback(description)
 
     def chat_with_workflow(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        image: Optional[Union[str, Path]] = None,
+        visualize_output: Optional[bool] = False,
     ) -> Tuple[str, List[Dict]]:
         question = chat[0]["content"]
         if image:
@@ -449,31 +459,42 @@ def chat_with_workflow(
                 self.answer_model, question, answers, reflections
             )
 
-            visualized_images = visualize_result(all_tool_results)
-            all_tool_results.append({"visualized_images": visualized_images})
+            visualized_output = visualize_result(all_tool_results)
+            all_tool_results.append({"visualized_output": visualized_output})
             reflection = self_reflect(
                 self.reflect_model,
                 question,
                 self.tools,
                 all_tool_results,
                 final_answer,
-                visualized_images[0] if len(visualized_images) > 0 else image,
+                visualized_output[0] if len(visualized_output) > 0 else image,
             )
             self.log_progress(f"Reflection: {reflection}")
             if parse_reflect(reflection):
                 break
             else:
-                reflections += reflection
-        # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
+                reflections += "\n" + reflection
+        # '' is a symbol to indicate the end of the chat, which is useful for streaming logs.
         self.log_progress(
-            f"The Vision Agent has concluded this chat. {final_answer}"
+            f"The Vision Agent has concluded this chat. {final_answer}"
         )
+
+        if visualize_output:
+            visualized_output = all_tool_results[-1]["visualized_output"]
+            for image in visualized_output:
+                Image.open(image).show()
+
         return final_answer, all_tool_results
 
     def chat(
-        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+        self,
+        chat: List[Dict[str, str]],
+        image: Optional[Union[str, Path]] = None,
+        visualize_output: Optional[bool] = False,
     ) -> str:
-        answer, _ = self.chat_with_workflow(chat, image=image)
+        answer, _ = self.chat_with_workflow(
+            chat, image=image, visualize_output=visualize_output
+        )
         return answer
 
     def retrieval(
diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py
index f451a66e..a54ae6a6 100644
--- a/vision_agent/agent/vision_agent_prompts.py
+++ b/vision_agent/agent/vision_agent_prompts.py
@@ -1,4 +1,4 @@
-VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-refection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, high level plan that aims to mitigate the same failure with the tools available. Use complete sentences.
+VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not propose vague steps like "re-evaluate the threshold"; instead, propose concrete steps like "use a threshold of 0.5", or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
 
 User's question: {question}
@@ -49,7 +49,6 @@ CHOOSE_TOOL = """This is the user's question: {question}
 These are the tools you can select to solve the question:
-
 {tools}
 
 Please note that:
@@ -63,7 +62,6 @@ CHOOSE_TOOL_DEPENDS = """This is the user's question: {question}
 These are the tools you can select to solve the question:
-
 {tools}
 
 This is a reflection from a previous failed attempt:
diff --git a/vision_agent/fonts/__init__.py b/vision_agent/fonts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vision_agent/fonts/default_font_ch_en.ttf b/vision_agent/fonts/default_font_ch_en.ttf
new file mode 100644
index 00000000..0a52b108
Binary files /dev/null and b/vision_agent/fonts/default_font_ch_en.ttf differ
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 65ee5b01..d0164e11 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -1,6 +1,7 @@
 """Utility functions for image processing."""
 
 import base64
+from importlib import resources
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, Tuple, Union
@@ -104,19 +105,28 @@ def overlay_bboxes(
     color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
 
-    draw = ImageDraw.Draw(image)
-    font = ImageFont.load_default()
     width, height = image.size
+    fontsize = max(12, int(min(width, height) / 40))
+    draw = ImageDraw.Draw(image)
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )
 
     if "bboxes" not in bboxes:
         return image.convert("RGB")
 
-    for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
-        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
-        draw.rectangle(box, outline=color[label], width=3)
-        label = f"{label}"
-        text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
-        draw.rectangle(text_box, fill=color[label])
-        draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+    for label, box, score in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
+        box = [
+            int(box[0] * width),
+            int(box[1] * height),
+            int(box[2] * width),
+            int(box[3] * height),
+        ]
+        draw.rectangle(box, outline=color[label], width=4)
+        text = f"{label}: {score:.2f}"
+        text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
+        draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
+        draw.text((box[0], box[1]), text, fill="black", font=font)
 
     return image.convert("RGB")
@@ -138,7 +148,9 @@ def overlay_masks(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+    color = {
+        label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
+    }
 
     if "masks" not in masks:
         return image.convert("RGB")
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 700555cc..12450753 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -53,9 +53,7 @@ class Tool(ABC):
 class NoOp(Tool):
     name = "noop_"
-    description = (
-        "'noop_' is a no-op tool that does nothing if you do not need to use a tool."
-    )
+    description = "'noop_' is a no-op tool that does nothing. Use it if you want to answer the question directly and not use a tool."
     usage = {
         "required_parameters": [],
         "examples": [
@@ -85,7 +83,7 @@ class CLIP(Tool):
     _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "clip_"
-    description = "'clip_' is a tool that can classify or tag any image given a set of input classes or tags."
+    description = "'clip_' is a tool that can classify any image given a set of input names or tags. It returns a list of the input names along with their probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -163,7 +161,7 @@ class GroundingDINO(Tool):
     _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "grounding_dino_"
-    description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
+    description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. It returns a list of bounding boxes, label names and associated probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -179,8 +177,11 @@ class GroundingDINO(Tool):
             "parameters": {"prompt": "car", "image": ""},
         },
         {
-            "scenario": "Can you detect the person on the left? Image name: person.jpg",
-            "parameters": {"prompt": "person on the left", "image": "person.jpg"},
+            "scenario": "Can you detect the person on the left and right? Image name: person.jpg",
+            "parameters": {
+                "prompt": "left person. right person",
+                "image": "person.jpg",
+            },
         },
         {
             "scenario": "Detect the red shirts and green shirst. Image name: shirts.jpg",
@@ -269,7 +270,7 @@ class GroundingSAM(Tool):
     _ENDPOINT = "https://soi4ewr6fjqqdf5vuss6rrilee0kumxq.lambda-url.us-east-2.on.aws"
 
     name = "grounding_sam_"
-    description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
+    description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions. It returns a list of bounding boxes, label names, mask file names and associated probability scores."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -285,8 +286,11 @@ class GroundingSAM(Tool):
             "parameters": {"prompt": "car", "image": ""},
         },
         {
-            "scenario": "Can you segment the person on the left? Image name: person.jpg",
-            "parameters": {"prompt": "person on the left", "image": "person.jpg"},
+            "scenario": "Can you segment the person on the left and right? Image name: person.jpg",
+            "parameters": {
+                "prompt": "left person. right person",
+                "image": "person.jpg",
+            },
         },
         {
             "scenario": "Can you build me a tool that segments red shirts and green shirts? Image name: shirts.jpg",
@@ -370,8 +374,9 @@ def __call__(
         mask_files = []
         for mask in rets["masks"]:
             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-                Image.fromarray(mask * 255).save(tmp)
-                mask_files.append(tmp.name)
+                file_name = Path(tmp.name).with_suffix(".mask.png")
+                Image.fromarray(mask * 255).save(file_name)
+                mask_files.append(str(file_name))
         rets["masks"] = mask_files
         return rets
@@ -380,7 +385,7 @@ class Counter(Tool):
     r"""Counter detects and counts the number of objects in an image given an input such as a category name or referring expression."""
 
     name = "counter_"
-    description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression."
+    description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression. It returns a dictionary containing the labels and their counts."
     usage = {
         "required_parameters": [
             {"name": "prompt", "type": "str"},
@@ -400,14 +405,14 @@ class Counter(Tool):
 
     def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
         resp = GroundingDINO()(prompt, image)
-        return dict(CounterClass(resp[0]["labels"]))
+        return dict(CounterClass(resp["labels"]))
 
 
 class Crop(Tool):
     r"""Crop crops an image given a bounding box and returns a file name of the cropped image."""
 
     name = "crop_"
-    description = "'crop_' crops an image given a bounding box and returns a file name of the cropped image."
+    description = "'crop_' crops an image given a bounding box. It returns a file name of the cropped image."
     usage = {
         "required_parameters": [
             {"name": "bbox", "type": "List[float]"},
@@ -495,9 +500,7 @@ def __call__(self, masks: Union[str, Path]) -> float:
 
 class BboxIoU(Tool):
     name = "bbox_iou_"
-    description = (
-        "'bbox_iou_' returns the intersection over union of two bounding boxes."
-    )
+    description = "'bbox_iou_' returns the intersection over union of two bounding boxes. This is a good tool for determining if two objects are overlapping."
     usage = {
         "required_parameters": [
             {"name": "bbox1", "type": "List[int]"},
@@ -591,85 +594,35 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
         )
         for frame, ts in frames:
             with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
-                Image.fromarray(frame).save(tmp)
-                result.append((tmp.name, ts))
+                file_name = Path(tmp.name).with_suffix(".frame.png")
+                Image.fromarray(frame).save(file_name)
+                result.append((str(file_name), ts))
         return result
 
 
-class Add(Tool):
-    r"""Add returns the sum of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "add_"
-    description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
-    usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
-        "examples": [
-            {
-                "scenario": "If you want to calculate 2 + 4",
-                "parameters": {"input": [2, 4]},
-            }
-        ],
-    }
-
-    def __call__(self, input: List[int]) -> float:
-        return round(sum(input), 2)
-
-
-class Subtract(Tool):
-    r"""Subtract returns the difference of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "subtract_"
-    description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places."
-    usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
-        "examples": [
-            {
-                "scenario": "If you want to calculate 4 - 2",
-                "parameters": {"input": [4, 2]},
-            }
-        ],
-    }
-
-    def __call__(self, input: List[int]) -> float:
-        return round(input[0] - input[1], 2)
-
+class Calculator(Tool):
+    r"""Calculator is a tool that can perform basic arithmetic operations."""
 
-class Multiply(Tool):
-    r"""Multiply returns the product of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "multiply_"
-    description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places."
+    name = "calculator_"
+    description = (
+        "'calculator_' is a tool that can perform basic arithmetic operations."
+    )
     usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
+        "required_parameters": [{"name": "equation", "type": "str"}],
         "examples": [
             {
-                "scenario": "If you want to calculate 2 * 4",
-                "parameters": {"input": [2, 4]},
-            }
-        ],
-    }
-
-    def __call__(self, input: List[int]) -> float:
-        return round(input[0] * input[1], 2)
-
-
-class Divide(Tool):
-    r"""Divide returns the division of all the arguments passed to it, normalized to 2 decimal places."""
-
-    name = "divide_"
-    description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places."
-    usage = {
-        "required_parameters": [{"name": "input", "type": "List[int]"}],
-        "examples": [
+                "scenario": "If you want to calculate (2 * 3) + 4",
+                "parameters": {"equation": "(2 * 3) + 4"},
+            },
             {
-                "scenario": "If you want to calculate 4 / 2",
-                "parameters": {"input": [4, 2]},
-            }
+                "scenario": "If you want to calculate (4 + 2.5) / 2.1",
+                "parameters": {"equation": "(4 + 2.5) / 2.1"},
+            },
         ],
     }
 
-    def __call__(self, input: List[int]) -> float:
-        return round(input[0] / input[1], 2)
+    def __call__(self, equation: str) -> float:
+        return cast(float, round(eval(equation), 2))
 
 
 TOOLS = {
@@ -687,10 +640,7 @@ def __call__(self, input: List[int]) -> float:
             SegArea,
             BboxIoU,
             SegIoU,
-            Add,
-            Subtract,
-            Multiply,
-            Divide,
+            Calculator,
         ]
     )
     if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
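
Taken together, the `vision_agent.py` changes thread an opt-in `visualize_output` flag from `__call__` through `chat` and `chat_with_workflow`, and open every image returned by `visualize_result` with PIL. A minimal usage sketch follows; the `VisionAgent` class name and `vision_agent.agent` import path are assumptions (not shown in this diff), and `shelf.jpg` is a hypothetical local image:

```python
from vision_agent.agent import VisionAgent  # import path assumed, not shown in the diff

agent = VisionAgent()

# visualize_output=True makes the agent call Image.open(...).show() on each
# file produced by visualize_result(), i.e. the boxes/masks overlaid with
# their labels and scores.
answer = agent(
    "Where is the jar?",
    image="shelf.jpg",  # hypothetical input image
    visualize_output=True,
)
print(answer)
```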
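On the rendering side, `overlay_bboxes` now expects a `scores` list alongside `bboxes` and `labels` (the three lists must be parallel, since `zip` truncates to the shortest), and draws `label: score` in the bundled TrueType font. A sketch of the expected input shape, assuming the two positional parameters shown in the hunk (`image`, `bboxes`); the sample values are illustrative, and coordinates are normalized xyxy because the loop multiplies them by the image width and height:

```python
from PIL import Image

from vision_agent.image_utils import overlay_bboxes

image = Image.open("shelf.jpg")  # hypothetical image
detections = {
    "labels": ["jar"],
    "bboxes": [[0.58, 0.2, 0.72, 0.45]],  # normalized x1, y1, x2, y2
    "scores": [0.91],                     # now required by the zip() over scores
}
# overlay_bboxes returns a PIL image converted to RGB
overlay_bboxes(image, detections).save("shelf_annotated.png")
```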
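Finally, the four arithmetic tools collapse into a single `calculator_` tool that evaluates an equation string. A usage sketch, with the import path assumed; note that the tool runs the string through Python's `eval`, so it should only ever see trusted, agent-generated arithmetic:

```python
from vision_agent.tools import Calculator  # import path assumed

calc = Calculator()
print(calc(equation="(2 * 3) + 4"))      # 10
print(calc(equation="(4 + 2.5) / 2.1"))  # 3.1, rounded to 2 decimal places
```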