Skip to content

Commit

Permalink
add image visualization for reflection
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 29, 2024
1 parent 98482e9 commit 0507f6a
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 24 deletions.
File renamed without changes.
92 changes: 73 additions & 19 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
import logging
import sys
import tempfile
from os import walk
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_masks
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -248,12 +251,12 @@ def retrieval(
tools: Dict[int, Any],
previous_log: str,
reflections: str,
) -> Tuple[List[Dict], str]:
) -> Tuple[Dict, str]:
tool_id = choose_tool(
model, question, {k: v["description"] for k, v in tools.items()}, reflections
)
if tool_id is None:
return [{}], ""
return {}, ""
_LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

tool_instructions = tools[tool_id]
Expand All @@ -265,14 +268,12 @@ def retrieval(
)
_LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
if parameters is None:
return [{}], ""
tool_results = [
{"task": question, "tool_name": tool_name, "parameters": parameters}
]
return {}, ""
tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}

_LOGGER.info(
f"""Going to run the following {len(tool_results)} tool(s) in sequence:
{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
Expand All @@ -286,12 +287,10 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results.append(function_call(tools[tool_id]["class"], parameters))
return call_results

call_results = []
for i, result in enumerate(tool_results):
call_results.extend(parse_tool_results(result))
tool_results[i]["call_results"] = call_results
call_results = parse_tool_results(tool_results)
tool_results["call_results"] = call_results

call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
call_results_str = str(call_results)
_LOGGER.info(f"\tCall Results: {call_results_str}")
return tool_results, call_results_str

Expand Down Expand Up @@ -335,7 +334,11 @@ def self_reflect(
tool_results=str(tool_result),
final_answer=final_answer,
)
if issubclass(type(reflect_model), LMM):
if (
issubclass(type(reflect_model), LMM)
and image is not None
and Path(image).suffix in [".jpg", ".jpeg", ".png"]
):
return reflect_model(prompt, image=image) # type: ignore
return reflect_model(prompt)

Expand All @@ -345,6 +348,56 @@ def parse_reflect(reflect: str) -> bool:
return "finish" in reflect.lower() and len(reflect) < 100


def visualize_result(all_tool_results: List[Dict]) -> List[str]:
    r"""Draw detection and segmentation tool outputs on top of their input images.

    Only results produced by the ``grounding_sam_`` and ``grounding_dino_`` tools
    are considered. Results are grouped by the image they were computed on, so
    every distinct input image yields exactly one visualization.

    Parameters:
        all_tool_results: the accumulated tool result dictionaries. Each entry is
            expected to carry "tool_name", "parameters" and "call_results" keys,
            but entries that are empty or malformed (the parameters are built by
            an LLM and a tool may not have been chosen at all) are skipped
            instead of raising.

    Returns:
        The file paths of the rendered PNG images. The files are created with
        ``delete=False``, so they outlive this call and the caller owns them.
    """
    image_to_data: Dict[Any, Dict[str, list]] = {}
    for tool_result in all_tool_results:
        # .get() because a step for which no tool could be chosen is recorded as
        # a dictionary without a "tool_name" key
        if tool_result.get("tool_name") not in ("grounding_sam_", "grounding_dino_"):
            continue

        # parameters can either be a dictionary or list, parameters can also be
        # malformed because the LLM builds them
        parameters = tool_result.get("parameters")
        if isinstance(parameters, dict):
            parameters = [parameters]
        elif not isinstance(parameters, list):
            continue

        call_results = tool_result.get("call_results") or []
        for param, call_result in zip(parameters, call_results):
            # every parameter set is validated, not just the first one
            if not isinstance(param, dict) or "image" not in param:
                continue
            image = param["image"]
            if not isinstance(image, (str, Path)):
                continue

            # calls can fail, so we need to check if the call was successful
            if not isinstance(call_result, dict):
                continue
            if "bboxes" not in call_result or "labels" not in call_result:
                continue

            data = image_to_data.setdefault(
                image, {"bboxes": [], "labels": [], "masks": [], "mask_labels": []}
            )
            labels = call_result["labels"]
            for label, bbox in zip(labels, call_result["bboxes"]):
                data["labels"].append(label)
                data["bboxes"].append(bbox)
            # masks keep their own label list: a detection-only call on the same
            # image adds boxes but no masks, which would otherwise shift every
            # later mask onto the wrong label
            for label, mask in zip(labels, call_result.get("masks") or []):
                data["mask_labels"].append(label)
                data["masks"].append(mask)

    visualized_images = []
    for image, data in image_to_data.items():
        rendered = overlay_masks(
            Path(image), {"labels": data["mask_labels"], "masks": data["masks"]}
        )
        rendered = overlay_bboxes(
            rendered, {"labels": data["labels"], "bboxes": data["bboxes"]}
        )
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            # write through the open handle; re-opening the file by name while it
            # is still open is not possible on every platform
            rendered.save(f, format="PNG")
        visualized_images.append(f.name)
    return visualized_images


class VisionAgent(Agent):
r"""Vision Agent is an agent framework that utilizes tools as well as self
reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
Expand Down Expand Up @@ -389,7 +442,8 @@ def __call__(
"""Invoke the vision agent.
Parameters:
input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}].
input: a prompt that describes the task or a conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the prompt parameter.
Returns:
Expand Down Expand Up @@ -436,9 +490,8 @@ def chat_with_workflow(
self.answer_model, task_str, call_results, previous_log, reflections
)

for tool_result in tool_results:
tool_result["answer"] = answer
all_tool_results.extend(tool_results)
tool_results["answer"] = answer
all_tool_results.append(tool_results)

_LOGGER.info(f"\tAnswer: {answer}")
answers.append({"task": task_str, "answer": answer})
Expand All @@ -448,13 +501,14 @@ def chat_with_workflow(
self.answer_model, question, answers, reflections
)

visualized_images = visualize_result(all_tool_results)
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
image,
visualized_images[0] if len(visualized_images) > 0 else image,
)
_LOGGER.info(f"\tReflection: {reflection}")
if parse_reflect(reflection):
Expand Down
100 changes: 95 additions & 5 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,38 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Tuple, Union
from typing import Dict, Tuple, Union

import numpy as np
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from PIL.Image import Image as ImageType

# Palette of (R, G, B) tuples used by overlay_bboxes and overlay_masks. A label is
# given the entry at its position in the label list, wrapping around with a modulo
# once there are more labels than palette entries.
COLORS = [
(158, 218, 229),
(219, 219, 141),
(23, 190, 207),
(188, 189, 34),
(199, 199, 199),
(247, 182, 210),
(127, 127, 127),
(227, 119, 194),
(196, 156, 148),
(197, 176, 213),
(140, 86, 75),
(148, 103, 189),
(255, 152, 150),
(152, 223, 138),
(214, 39, 40),
(44, 160, 44),
(255, 187, 120),
(174, 199, 232),
(255, 127, 14),
(31, 119, 180),
]


def b64_to_pil(b64_str: str) -> ImageType:
"""Convert a base64 string to a PIL Image.
r"""Convert a base64 string to a PIL Image.
Parameters:
b64_str: the base64 encoded image
Expand All @@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:


def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
"""Get the size of an image.
r"""Get the size of an image.
Parameters:
data: the input image
Expand All @@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,


def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
"""Convert an image to a base64 string.
r"""Convert an image to a base64 string.
Parameters:
data: the input image
Expand All @@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
else:
arr_bytes = data.tobytes()
return base64.b64encode(arr_bytes).decode("utf-8")


def overlay_bboxes(
    image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
) -> ImageType:
    r"""Plots bounding boxes on to an image.

    Parameters:
        image: the input image, given as a file path, a numpy array or a PIL image
        bboxes: a dictionary with a "labels" list and a "bboxes" list in matching
            order. Each box holds four values that are scaled by the image width
            (indices 0 and 2) and height (indices 1 and 3) before drawing, so they
            are expected to be fractions of the image size.

    Returns:
        A new RGB image with the bounding boxes overlaid. The input image is
        never drawn on.
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # convert() always hands back a new image, so the drawing below cannot leak
    # into an image object owned by the caller, and the RGB palette colors are
    # valid for whatever mode the input had
    image = image.convert("RGB")
    if "bboxes" not in bboxes:
        return image

    labels = bboxes.get("labels", [])
    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    width, height = image.size

    for label, box in zip(labels, bboxes["bboxes"]):
        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
        # look the color up once with the original label object; the text form
        # below is a different key whenever the label is not already a string
        box_color = color[label]
        draw.rectangle(box, outline=box_color, width=3)
        text = f"{label}"
        text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
        draw.rectangle(text_box, fill=box_color)
        draw.text((text_box[0], text_box[1]), text, fill="black", font=font)
    return image


def overlay_masks(
    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
) -> ImageType:
    r"""Plots masks on to an image.

    Parameters:
        image: the input image, given as a file path, a numpy array or a PIL image
        masks: a dictionary with a "labels" list and a "masks" list in matching
            order. A mask is either a path to a mask image or array-like data;
            every element greater than zero is treated as part of the mask.
        alpha: the opacity of the overlay, multiplied by 255 to obtain the alpha
            channel value of the mask color

    Returns:
        A new RGB image with the masks overlaid
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    if "masks" not in masks:
        return image.convert("RGB")

    labels = masks.get("labels", [])
    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

    # alpha_composite needs RGBA on both sides; convert once instead of on every
    # iteration (the composite result is RGBA again)
    image = image.convert("RGBA")
    for label, mask in zip(labels, masks["masks"]):
        if isinstance(mask, (str, Path)):
            # read the pixels while the file is open so the handle is released
            with Image.open(mask) as mask_file:
                mask = np.array(mask_file)
        else:
            # lists and PIL images do not support the comparison below
            mask = np.asarray(mask)
        np_mask = np.zeros((image.size[1], image.size[0], 4))
        np_mask[mask > 0, :] = color[label] + (255 * alpha,)
        mask_img = Image.fromarray(np_mask.astype(np.uint8))
        image = Image.alpha_composite(image, mask_img)
    return image.convert("RGB")

0 comments on commit 0507f6a

Please sign in to comment.