Update Tools to Handle Video (#32)
* added iou tools

* add image visualization for reflection

* update tool return format

* typing and flake8 issues

* added visualized images to all_tool_results
dillonalaird authored Mar 29, 2024
1 parent fdae1fd commit bbe0a68
Showing 6 changed files with 339 additions and 89 deletions.
26 changes: 26 additions & 0 deletions tests/tools/test_tools.py
@@ -0,0 +1,26 @@
+import os
+import tempfile
+
+import numpy as np
+from PIL import Image
+
+from vision_agent.tools.tools import BboxIoU, SegIoU
+
+
+def test_bbox_iou():
+    bbox1 = [0, 0, 0.75, 0.75]
+    bbox2 = [0.25, 0.25, 1, 1]
+    assert BboxIoU()(bbox1, bbox2) == 0.29
+
+
+def test_seg_iou():
+    mask1 = np.zeros((10, 10), dtype=np.uint8)
+    mask1[2:4, 2:4] = 255
+    mask2 = np.zeros((10, 10), dtype=np.uint8)
+    mask2[3:5, 3:5] = 255
+    with tempfile.TemporaryDirectory() as tmpdir:
+        mask1_path = os.path.join(tmpdir, "mask1.png")
+        mask2_path = os.path.join(tmpdir, "mask2.png")
+        Image.fromarray(mask1).save(mask1_path)
+        Image.fromarray(mask2).save(mask2_path)
+        assert SegIoU()(mask1_path, mask2_path) == 0.14
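The expected values are plain intersection-over-union rounded to two decimals: 0.25 overlap over 0.875 union for the boxes, and 1 shared pixel over 7 total for the masks. A minimal sketch of the math these tests pin down (hypothetical helper names; the real implementations are the BboxIoU and SegIoU classes in vision_agent.tools.tools):

import numpy as np
from PIL import Image


def bbox_iou(b1, b2):
    # Overlap rectangle in the same normalized xyxy coordinates as the boxes.
    x1, y1 = max(b1[0], b2[0]), max(b1[1], b2[1])
    x2, y2 = min(b1[2], b2[2]), min(b1[3], b2[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    return round(inter / (area1 + area2 - inter), 2)


def seg_iou(mask_path1, mask_path2):
    # Masks are stored as images, so any nonzero pixel counts as foreground.
    m1 = np.asarray(Image.open(mask_path1)) > 0
    m2 = np.asarray(Image.open(mask_path2)) > 0
    return round(np.logical_and(m1, m2).sum() / np.logical_or(m1, m2).sum(), 2)


# 0.25 / 0.875 ~= 0.29, matching test_bbox_iou above.
assert bbox_iou([0, 0, 0.75, 0.75], [0.25, 0.25, 1, 1]) == 0.29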
92 changes: 73 additions & 19 deletions vision_agent/agent/vision_agent.py
@@ -1,11 +1,13 @@
 import json
 import logging
 import sys
+import tempfile
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 from tabulate import tabulate

+from vision_agent.image_utils import overlay_bboxes, overlay_masks
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -248,12 +250,12 @@ def retrieval(
     tools: Dict[int, Any],
     previous_log: str,
     reflections: str,
-) -> Tuple[List[Dict], str]:
+) -> Tuple[Dict, str]:
     tool_id = choose_tool(
         model, question, {k: v["description"] for k, v in tools.items()}, reflections
     )
     if tool_id is None:
-        return [{}], ""
+        return {}, ""
     _LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

     tool_instructions = tools[tool_id]
@@ -265,14 +267,12 @@ def retrieval(
     )
     _LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
     if parameters is None:
-        return [{}], ""
-    tool_results = [
-        {"task": question, "tool_name": tool_name, "parameters": parameters}
-    ]
+        return {}, ""
+    tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}

     _LOGGER.info(
-        f"""Going to run the following {len(tool_results)} tool(s) in sequence:
-{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
+        f"""Going to run the following tool(s) in sequence:
+{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
     )

     def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
@@ -286,12 +286,10 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
         call_results.append(function_call(tools[tool_id]["class"], parameters))
         return call_results

-    call_results = []
-    for i, result in enumerate(tool_results):
-        call_results.extend(parse_tool_results(result))
-        tool_results[i]["call_results"] = call_results
+    call_results = parse_tool_results(tool_results)
+    tool_results["call_results"] = call_results

-    call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
+    call_results_str = str(call_results)
     _LOGGER.info(f"\tCall Results: {call_results_str}")
     return tool_results, call_results_str
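Net effect of these hunks: retrieval now builds and returns a single tool-result dict rather than a one-element list, and the call results attach directly to that dict. An illustrative shape of the new return value (field values here are made up; the keys come from the diff, and the inner labels/bboxes keys match what visualize_result consumes below):

tool_results = {
    "task": "find the dogs in the image",  # hypothetical task string
    "tool_name": "grounding_dino_",
    "parameters": {"prompt": "dog", "image": "dogs.jpg"},
    "call_results": [{"labels": ["dog"], "bboxes": [[0.1, 0.2, 0.4, 0.6]]}],
}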

@@ -335,7 +333,11 @@ def self_reflect(
        tool_results=str(tool_result),
        final_answer=final_answer,
    )
-    if issubclass(type(reflect_model), LMM):
+    if (
+        issubclass(type(reflect_model), LMM)
+        and image is not None
+        and Path(image).suffix in [".jpg", ".jpeg", ".png"]
+    ):
        return reflect_model(prompt, image=image)  # type: ignore
    return reflect_model(prompt)
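This guard is the commit's video handling in miniature: the reflection step only sends image to a multimodal model when the path looks like a still image, so a video path degrades gracefully to a text-only call. Roughly:

from pathlib import Path

Path("clip.mp4").suffix in [".jpg", ".jpeg", ".png"]  # False -> reflect on text only
Path("dogs.png").suffix in [".jpg", ".jpeg", ".png"]  # True  -> image goes to the LMM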

@@ -345,6 +347,56 @@ def parse_reflect(reflect: str) -> bool:
return "finish" in reflect.lower() and len(reflect) < 100


def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
if not tool_result["tool_name"] in ["grounding_sam_", "grounding_dino_"]:
continue

parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# becaus the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
continue
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if (
len(tool_result["parameters"]) < 1
and "image" not in tool_result["parameters"][0]
):
continue

for param, call_result in zip(parameters, tool_result["call_results"]):

# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict):
continue
if "bboxes" not in call_result:
continue

# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {"bboxes": [], "masks": [], "labels": []}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])

visualized_images = []
for image in image_to_data:
image_path = Path(image)
image_data = image_to_data[image]
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
image.save(f.name)
visualized_images.append(f.name)
return visualized_images


class VisionAgent(Agent):
r"""Vision Agent is an agent framework that utilizes tools as well as self
reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
@@ -389,7 +441,8 @@ def __call__(
"""Invoke the vision agent.
Parameters:
input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}].
input: a prompt that describe the task or a conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the prompt parameter.
Returns:
@@ -436,9 +489,8 @@ def chat_with_workflow(
                self.answer_model, task_str, call_results, previous_log, reflections
            )

-            for tool_result in tool_results:
-                tool_result["answer"] = answer
-            all_tool_results.extend(tool_results)
+            tool_results["answer"] = answer
+            all_tool_results.append(tool_results)

            _LOGGER.info(f"\tAnswer: {answer}")
            answers.append({"task": task_str, "answer": answer})
@@ -448,13 +500,15 @@ def chat_with_workflow(
            self.answer_model, question, answers, reflections
        )

+        visualized_images = visualize_result(all_tool_results)
+        all_tool_results.append({"visualized_images": visualized_images})
        reflection = self_reflect(
            self.reflect_model,
            question,
            self.tools,
            all_tool_results,
            final_answer,
-            image,
+            visualized_images[0] if len(visualized_images) > 0 else image,
        )
        _LOGGER.info(f"\tReflection: {reflection}")
        if parse_reflect(reflection):
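Taken together, chat_with_workflow now renders the accumulated detection results before reflecting, and prefers the first overlay image over the raw input when reflecting with an LMM. A hedged usage sketch of visualize_result (file names and values are hypothetical):

all_tool_results = [
    {
        "tool_name": "grounding_dino_",
        "parameters": {"prompt": "dog", "image": "dogs.jpg"},
        "call_results": [{"labels": ["dog"], "bboxes": [[0.1, 0.2, 0.4, 0.6]]}],
    }
]
overlay_paths = visualize_result(all_tool_results)
# overlay_paths holds temporary .png files; the first is what self_reflect sees.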
100 changes: 95 additions & 5 deletions vision_agent/image_utils.py
@@ -3,15 +3,38 @@
 import base64
 from io import BytesIO
 from pathlib import Path
-from typing import Tuple, Union
+from typing import Dict, Tuple, Union

 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as ImageType

+COLORS = [
+    (158, 218, 229),
+    (219, 219, 141),
+    (23, 190, 207),
+    (188, 189, 34),
+    (199, 199, 199),
+    (247, 182, 210),
+    (127, 127, 127),
+    (227, 119, 194),
+    (196, 156, 148),
+    (197, 176, 213),
+    (140, 86, 75),
+    (148, 103, 189),
+    (255, 152, 150),
+    (152, 223, 138),
+    (214, 39, 40),
+    (44, 160, 44),
+    (255, 187, 120),
+    (174, 199, 232),
+    (255, 127, 14),
+    (31, 119, 180),
+]


 def b64_to_pil(b64_str: str) -> ImageType:
-    """Convert a base64 string to a PIL Image.
+    r"""Convert a base64 string to a PIL Image.

     Parameters:
         b64_str: the base64 encoded image
@@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:


 def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
-    """Get the size of an image.
+    r"""Get the size of an image.

     Parameters:
         data: the input image
@@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,


 def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
-    """Convert an image to a base64 string.
+    r"""Convert an image to a base64 string.

     Parameters:
         data: the input image
@@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     else:
         arr_bytes = data.tobytes()
     return base64.b64encode(arr_bytes).decode("utf-8")
+
+
+def overlay_bboxes(
+    image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
+) -> ImageType:
+    r"""Plots bounding boxes on to an image.
+
+    Parameters:
+        image: the input image
+        bboxes: the bounding boxes to overlay
+
+    Returns:
+        The image with the bounding boxes overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}
+
+    draw = ImageDraw.Draw(image)
+    font = ImageFont.load_default()
+    width, height = image.size
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
+    for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
+        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
+        draw.rectangle(box, outline=color[label], width=3)
+        label = f"{label}"
+        text_box = draw.textbbox((box[0], box[1]), text=label, font=font)
+        draw.rectangle(text_box, fill=color[label])
+        draw.text((text_box[0], text_box[1]), label, fill="black", font=font)
+    return image.convert("RGB")
+
+
+def overlay_masks(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
+) -> ImageType:
+    r"""Plots masks on to an image.
+
+    Parameters:
+        image: the input image
+        masks: the masks to overlay
+        alpha: the transparency of the overlay
+
+    Returns:
+        The image with the masks overlayed
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(masks["labels"])}
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    for label, mask in zip(masks["labels"], masks["masks"]):
+        if isinstance(mask, str):
+            mask = np.array(Image.open(mask))
+        np_mask = np.zeros((image.size[1], image.size[0], 4))
+        np_mask[mask > 0, :] = color[label] + (255 * alpha,)
+        mask_img = Image.fromarray(np_mask.astype(np.uint8))
+        image = Image.alpha_composite(image.convert("RGBA"), mask_img)
+    return image.convert("RGB")
15 changes: 14 additions & 1 deletion vision_agent/tools/__init__.py
@@ -1,2 +1,15 @@
 from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
-from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
+from .tools import (
+    CLIP,
+    TOOLS,
+    BboxArea,
+    BboxIoU,
+    Counter,
+    Crop,
+    ExtractFrames,
+    GroundingDINO,
+    GroundingSAM,
+    SegArea,
+    SegIoU,
+    Tool,
+)
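With the widened export list, the new tools are importable straight from the package root, matching the test at the top of this commit:

from vision_agent.tools import BboxIoU, SegIoU

assert BboxIoU()([0, 0, 0.75, 0.75], [0.25, 0.25, 1, 1]) == 0.29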