From 0507f6af111a8a0647cdc8c9a152be2eb668f2f6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 29 Mar 2024 10:44:14 -0700 Subject: [PATCH] add image visualization for reflection --- tests/{ => tools}/test_tools.py | 0 vision_agent/agent/vision_agent.py | 92 ++++++++++++++++++++------ vision_agent/image_utils.py | 100 +++++++++++++++++++++++++++-- 3 files changed, 168 insertions(+), 24 deletions(-) rename tests/{ => tools}/test_tools.py (100%) diff --git a/tests/test_tools.py b/tests/tools/test_tools.py similarity index 100% rename from tests/test_tools.py rename to tests/tools/test_tools.py diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 5d34fe9e..caba0533 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -1,11 +1,14 @@ import json import logging import sys +import tempfile +from os import walk from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union from tabulate import tabulate +from vision_agent.image_utils import overlay_bboxes, overlay_masks from vision_agent.llm import LLM, OpenAILLM from vision_agent.lmm import LMM, OpenAILMM from vision_agent.tools import TOOLS @@ -248,12 +251,12 @@ def retrieval( tools: Dict[int, Any], previous_log: str, reflections: str, -) -> Tuple[List[Dict], str]: +) -> Tuple[Dict, str]: tool_id = choose_tool( model, question, {k: v["description"] for k, v in tools.items()}, reflections ) if tool_id is None: - return [{}], "" + return {}, "" _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})") tool_instructions = tools[tool_id] @@ -265,14 +268,12 @@ def retrieval( ) _LOGGER.info(f"\tParameters: {parameters} for {tool_name}") if parameters is None: - return [{}], "" - tool_results = [ - {"task": question, "tool_name": tool_name, "parameters": parameters} - ] + return {}, "" + tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters} _LOGGER.info( - f"""Going to run the following {len(tool_results)} tool(s) in sequence: -{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}""" + f"""Going to run the following tool(s) in sequence: +{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}""" ) def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: @@ -286,12 +287,10 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: call_results.append(function_call(tools[tool_id]["class"], parameters)) return call_results - call_results = [] - for i, result in enumerate(tool_results): - call_results.extend(parse_tool_results(result)) - tool_results[i]["call_results"] = call_results + call_results = parse_tool_results(tool_results) + tool_results["call_results"] = call_results - call_results_str = "\n\n".join([str(e) for e in call_results if e is not None]) + call_results_str = str(call_results) _LOGGER.info(f"\tCall Results: {call_results_str}") return tool_results, call_results_str @@ -335,7 +334,11 @@ def self_reflect( tool_results=str(tool_result), final_answer=final_answer, ) - if issubclass(type(reflect_model), LMM): + if ( + issubclass(type(reflect_model), LMM) + and image is not None + and Path(image).suffix in [".jpg", ".jpeg", ".png"] + ): return reflect_model(prompt, image=image) # type: ignore return reflect_model(prompt) @@ -345,6 +348,56 @@ def parse_reflect(reflect: str) -> bool: return "finish" in reflect.lower() and len(reflect) < 100 +def visualize_result(all_tool_results: List[Dict]) -> List[str]: + image_to_data = {} + for tool_result in all_tool_results: + if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]: + continue + + parameters = tool_result["parameters"] + # parameters can either be a dictionary or list, parameters can also be malformed + # becaus the LLM builds them + if isinstance(parameters, dict): + if "image" not in parameters: + continue + parameters = [parameters] + elif isinstance(tool_result["parameters"], list): + if ( + len(tool_result["parameters"]) < 1 + and "image" not in tool_result["parameters"][0] + ): + continue + + for param, call_result in zip(parameters, tool_result["call_results"]): + + # calls can fail, so we need to check if the call was successful + if not isinstance(call_result, dict): + continue + if not "bboxes" in call_result: + continue + + # if the call was successful, then we can add the image data + image = param["image"] + if image not in image_to_data: + image_to_data[image] = {"bboxes": [], "masks": [], "labels": []} + + image_to_data[image]["bboxes"].extend(call_result["bboxes"]) + image_to_data[image]["labels"].extend(call_result["labels"]) + if "masks" in call_result: + image_to_data[image]["masks"].extend(call_result["masks"]) + + visualized_images = [] + for image in image_to_data: + image_path = Path(image) + image_data = image_to_data[image] + image = overlay_masks(image_path, image_data) + image = overlay_bboxes(image, image_data) + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + image.save(f.name) + visualized_images.append(f.name) + return visualized_images + + class VisionAgent(Agent): r"""Vision Agent is an agent framework that utilizes tools as well as self reflection to accomplish tasks, in particular vision tasks. Vision Agent is based @@ -389,7 +442,8 @@ def __call__( """Invoke the vision agent. Parameters: - input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}]. + input: a prompt that describe the task or a conversation in the format of + [{"role": "user", "content": "describe your task here..."}]. image: the input image referenced in the prompt parameter. Returns: @@ -436,9 +490,8 @@ def chat_with_workflow( self.answer_model, task_str, call_results, previous_log, reflections ) - for tool_result in tool_results: - tool_result["answer"] = answer - all_tool_results.extend(tool_results) + tool_results["answer"] = answer + all_tool_results.append(tool_results) _LOGGER.info(f"\tAnswer: {answer}") answers.append({"task": task_str, "answer": answer}) @@ -448,13 +501,14 @@ def chat_with_workflow( self.answer_model, question, answers, reflections ) + visualized_images = visualize_result(all_tool_results) reflection = self_reflect( self.reflect_model, question, self.tools, all_tool_results, final_answer, - image, + visualized_images[0] if len(visualized_images) > 0 else image, ) _LOGGER.info(f"\tReflection: {reflection}") if parse_reflect(reflection): diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 05a129ce..849f912f 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -3,15 +3,38 @@ import base64 from io import BytesIO from pathlib import Path -from typing import Tuple, Union +from typing import Dict, Tuple, Union import numpy as np -from PIL import Image +from PIL import Image, ImageDraw, ImageFont from PIL.Image import Image as ImageType +COLORS = [ + (158, 218, 229), + (219, 219, 141), + (23, 190, 207), + (188, 189, 34), + (199, 199, 199), + (247, 182, 210), + (127, 127, 127), + (227, 119, 194), + (196, 156, 148), + (197, 176, 213), + (140, 86, 75), + (148, 103, 189), + (255, 152, 150), + (152, 223, 138), + (214, 39, 40), + (44, 160, 44), + (255, 187, 120), + (174, 199, 232), + (255, 127, 14), + (31, 119, 180), +] + def b64_to_pil(b64_str: str) -> ImageType: - """Convert a base64 string to a PIL Image. + r"""Convert a base64 string to a PIL Image. Parameters: b64_str: the base64 encoded image @@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType: def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]: - """Get the size of an image. + r"""Get the size of an image. Parameters: data: the input image @@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: - """Convert an image to a base64 string. + r"""Convert an image to a base64 string. Parameters: data: the input image @@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: else: arr_bytes = data.tobytes() return base64.b64encode(arr_bytes).decode("utf-8") + + +def overlay_bboxes( + image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict +) -> ImageType: + r"""Plots bounding boxes on to an image. + + Parameters: + image: the input image + bboxes: the bounding boxes to overlay + + Returns: + The image with the bounding boxes overlayed + """ + if isinstance(image, (str, Path)): + image = Image.open(image) + elif isinstance(image, np.ndarray): + image = Image.fromarray(image) + + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])} + + draw = ImageDraw.Draw(image) + font = ImageFont.load_default() + width, height = image.size + if "bboxes" not in bboxes: + return image.convert("RGB") + + for label, box in zip(bboxes["labels"], bboxes["bboxes"]): + box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height] + draw.rectangle(box, outline=color[label], width=3) + label = f"{label}" + text_box = draw.textbbox((box[0], box[1]), text=label, font=font) + draw.rectangle(text_box, fill=color[label]) + draw.text((text_box[0], text_box[1]), label, fill="black", font=font) + return image.convert("RGB") + + +def overlay_masks( + image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5 +) -> ImageType: + r"""Plots masks on to an image. + + Parameters: + image: the input image + masks: the masks to overlay + alpha: the transparency of the overlay + + Returns: + The image with the masks overlayed + """ + if isinstance(image, (str, Path)): + image = Image.open(image) + elif isinstance(image, np.ndarray): + image = Image.fromarray(image) + + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])} + if "masks" not in masks: + return image.convert("RGB") + + for label, mask in zip(masks["labels"], masks["masks"]): + if isinstance(mask, str): + mask = np.array(Image.open(mask)) + np_mask = np.zeros((image.size[1], image.size[0], 4)) + np_mask[mask > 0, :] = color[label] + (255 * alpha,) + mask_img = Image.fromarray(np_mask.astype(np.uint8)) + image = Image.alpha_composite(image.convert("RGBA"), mask_img) + return image.convert("RGB")