Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Tools to Handle Video #32

Merged
merged 5 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/tools/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import tempfile

import numpy as np
from PIL import Image

from vision_agent.tools.tools import BboxIoU, SegIoU


def test_bbox_iou():
    """Check BboxIoU on two partially overlapping boxes."""
    # The boxes share the square [0.25, 0.75] x [0.25, 0.75]: intersection 0.25,
    # union 0.875, ratio ~0.2857. The expected 0.29 presumably reflects the tool
    # rounding to two decimals -- confirm against BboxIoU.
    first_box = [0, 0, 0.75, 0.75]
    second_box = [0.25, 0.25, 1, 1]
    iou = BboxIoU()(first_box, second_box)
    assert iou == 0.29


def test_seg_iou():
    """Check SegIoU on two mask files whose foregrounds overlap in one pixel."""
    first_mask = np.zeros((10, 10), dtype=np.uint8)
    second_mask = np.zeros((10, 10), dtype=np.uint8)
    # Two 2x2 squares offset by one pixel: 1 shared pixel out of 7 covered.
    first_mask[2:4, 2:4] = 255
    second_mask[3:5, 3:5] = 255

    with tempfile.TemporaryDirectory() as workdir:
        first_path = os.path.join(workdir, "mask1.png")
        second_path = os.path.join(workdir, "mask2.png")
        # SegIoU is called with file paths, so the masks are written to disk first.
        for mask, path in ((first_mask, first_path), (second_mask, second_path)):
            Image.fromarray(mask).save(path)
        iou = SegIoU()(first_path, second_path)
        assert iou == 0.14
92 changes: 73 additions & 19 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
import logging
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_masks
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -248,12 +250,12 @@ def retrieval(
tools: Dict[int, Any],
previous_log: str,
reflections: str,
) -> Tuple[List[Dict], str]:
) -> Tuple[Dict, str]:
tool_id = choose_tool(
model, question, {k: v["description"] for k, v in tools.items()}, reflections
)
if tool_id is None:
return [{}], ""
return {}, ""
_LOGGER.info(f"\t(Tool ID, name): ({tool_id}, {tools[tool_id]['name']})")

tool_instructions = tools[tool_id]
Expand All @@ -265,14 +267,12 @@ def retrieval(
)
_LOGGER.info(f"\tParameters: {parameters} for {tool_name}")
if parameters is None:
return [{}], ""
tool_results = [
{"task": question, "tool_name": tool_name, "parameters": parameters}
]
return {}, ""
tool_results = {"task": question, "tool_name": tool_name, "parameters": parameters}

_LOGGER.info(
f"""Going to run the following {len(tool_results)} tool(s) in sequence:
{tabulate(tool_results, headers="keys", tablefmt="mixed_grid")}"""
f"""Going to run the following tool(s) in sequence:
{tabulate([tool_results], headers="keys", tablefmt="mixed_grid")}"""
)

def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
Expand All @@ -286,12 +286,10 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results.append(function_call(tools[tool_id]["class"], parameters))
return call_results

call_results = []
for i, result in enumerate(tool_results):
call_results.extend(parse_tool_results(result))
tool_results[i]["call_results"] = call_results
call_results = parse_tool_results(tool_results)
tool_results["call_results"] = call_results

call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
call_results_str = str(call_results)
_LOGGER.info(f"\tCall Results: {call_results_str}")
return tool_results, call_results_str

Expand Down Expand Up @@ -335,7 +333,11 @@ def self_reflect(
tool_results=str(tool_result),
final_answer=final_answer,
)
if issubclass(type(reflect_model), LMM):
if (
issubclass(type(reflect_model), LMM)
and image is not None
and Path(image).suffix in [".jpg", ".jpeg", ".png"]
):
return reflect_model(prompt, image=image) # type: ignore
return reflect_model(prompt)

Expand All @@ -345,6 +347,56 @@ def parse_reflect(reflect: str) -> bool:
return "finish" in reflect.lower() and len(reflect) < 100


def visualize_result(all_tool_results: List[Dict]) -> List[str]:
    r"""Draw grounding results on their source images and save them to disk.

    Only results produced by the ``grounding_sam_`` and ``grounding_dino_`` tools
    are considered. Detections are grouped by the ``image`` value of the call
    parameters, so every distinct image yields exactly one output file.

    Parameters:
        all_tool_results: the accumulated tool results. Entries are expected to
            hold "tool_name", "parameters" and "call_results"; entries that are
            missing these keys or are otherwise malformed are skipped.

    Returns:
        The paths of the temporary PNG files holding the visualizations. The files
        are created with ``delete=False``, so the caller owns their cleanup.
    """
    image_to_data: Dict[str, Dict] = {}
    for tool_result in all_tool_results:
        # entries without a tool (e.g. the empty result returned when no tool could
        # be chosen) carry no "tool_name", so use .get instead of indexing
        if tool_result.get("tool_name") not in ["grounding_sam_", "grounding_dino_"]:
            continue

        # parameters can either be a dictionary or list, parameters can also be
        # malformed because the LLM builds them
        parameters = tool_result.get("parameters")
        if isinstance(parameters, dict):
            parameters = [parameters]
        elif not isinstance(parameters, list):
            continue

        call_results = tool_result.get("call_results")
        if not isinstance(call_results, list):
            continue

        for param, call_result in zip(parameters, call_results):
            # every parameter set is validated on its own, a single malformed entry
            # must not discard (or crash on) the remaining ones
            if not isinstance(param, dict) or "image" not in param:
                continue

            # calls can fail, so we need to check if the call was successful
            if not isinstance(call_result, dict) or "bboxes" not in call_result:
                continue

            # if the call was successful, then we can add the image data
            image_data = image_to_data.setdefault(
                param["image"], {"bboxes": [], "masks": [], "labels": []}
            )
            image_data["bboxes"].extend(call_result["bboxes"])
            image_data["labels"].extend(call_result["labels"])
            if "masks" in call_result:
                image_data["masks"].extend(call_result["masks"])

    visualized_images = []
    for image_name, image_data in image_to_data.items():
        overlaid = overlay_masks(Path(image_name), image_data)
        overlaid = overlay_bboxes(overlaid, image_data)
        # reserve the file name first and write only after the handle is closed,
        # re-opening a still open temporary file by name is not portable
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            output_path = f.name
        overlaid.save(output_path)
        visualized_images.append(output_path)
    return visualized_images


class VisionAgent(Agent):
r"""Vision Agent is an agent framework that utilizes tools as well as self
reflection to accomplish tasks, in particular vision tasks. Vision Agent is based
Expand Down Expand Up @@ -389,7 +441,8 @@ def __call__(
"""Invoke the vision agent.

Parameters:
input: a prompt that describe the task or a conversation in the format of [{"role": "user", "content": "describe your task here..."}].
input: a prompt that describes the task or a conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the prompt parameter.

Returns:
Expand Down Expand Up @@ -436,9 +489,8 @@ def chat_with_workflow(
self.answer_model, task_str, call_results, previous_log, reflections
)

for tool_result in tool_results:
tool_result["answer"] = answer
all_tool_results.extend(tool_results)
tool_results["answer"] = answer
all_tool_results.append(tool_results)

_LOGGER.info(f"\tAnswer: {answer}")
answers.append({"task": task_str, "answer": answer})
Expand All @@ -448,13 +500,15 @@ def chat_with_workflow(
self.answer_model, question, answers, reflections
)

visualized_images = visualize_result(all_tool_results)
all_tool_results.append({"visualized_images": visualized_images})
reflection = self_reflect(
self.reflect_model,
question,
self.tools,
all_tool_results,
final_answer,
image,
visualized_images[0] if len(visualized_images) > 0 else image,
)
_LOGGER.info(f"\tReflection: {reflection}")
if parse_reflect(reflection):
Expand Down
100 changes: 95 additions & 5 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,38 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Tuple, Union
from typing import Dict, Tuple, Union

import numpy as np
from PIL import Image
from PIL import Image, ImageDraw, ImageFont
from PIL.Image import Image as ImageType

# Palette of 20 RGB triples used to color overlays. The overlay helpers pick an
# entry with `COLORS[i % len(COLORS)]`, so colors repeat once the label index
# exceeds the palette size.
COLORS = [
(158, 218, 229),
(219, 219, 141),
(23, 190, 207),
(188, 189, 34),
(199, 199, 199),
(247, 182, 210),
(127, 127, 127),
(227, 119, 194),
(196, 156, 148),
(197, 176, 213),
(140, 86, 75),
(148, 103, 189),
(255, 152, 150),
(152, 223, 138),
(214, 39, 40),
(44, 160, 44),
(255, 187, 120),
(174, 199, 232),
(255, 127, 14),
(31, 119, 180),
]


def b64_to_pil(b64_str: str) -> ImageType:
"""Convert a base64 string to a PIL Image.
r"""Convert a base64 string to a PIL Image.

Parameters:
b64_str: the base64 encoded image
Expand All @@ -26,7 +49,7 @@ def b64_to_pil(b64_str: str) -> ImageType:


def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
"""Get the size of an image.
r"""Get the size of an image.

Parameters:
data: the input image
Expand All @@ -41,7 +64,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int,


def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
"""Convert an image to a base64 string.
r"""Convert an image to a base64 string.

Parameters:
data: the input image
Expand All @@ -60,3 +83,70 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
else:
arr_bytes = data.tobytes()
return base64.b64encode(arr_bytes).decode("utf-8")


def overlay_bboxes(
    image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
) -> ImageType:
    r"""Plots bounding boxes on to an image.

    Parameters:
        image: the input image, given as a file path, a numpy array or a PIL image.
            The input is never modified, drawing happens on an RGB copy.
        bboxes: a dictionary with a "labels" list and a "bboxes" list. Each box is
            ``[x0, y0, x1, y1]`` and is scaled by the image width and height, i.e.
            the coordinates are treated as fractions of the image size. If
            "bboxes" is missing the image is returned without annotations.

    Returns:
        The image with the bounding boxes overlaid, in RGB mode
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Draw on an RGB copy: RGB outline colors cannot be used on single-band
    # (grayscale) images, and a PIL image passed in by the caller must not be
    # mutated as a side effect.
    canvas = image.convert("RGB")
    if "bboxes" not in bboxes:
        return canvas

    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(bboxes["labels"])}

    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()
    width, height = canvas.size

    for label, box in zip(bboxes["labels"], bboxes["bboxes"]):
        box = [box[0] * width, box[1] * height, box[2] * width, box[3] * height]
        draw.rectangle(box, outline=color[label], width=3)
        # keep `label` untouched so the color lookup also works for labels that
        # are not strings, only the rendered caption is converted
        text = f"{label}"
        text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
        draw.rectangle(text_box, fill=color[label])
        draw.text((text_box[0], text_box[1]), text, fill="black", font=font)
    return canvas


def overlay_masks(
    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.5
) -> ImageType:
    r"""Blends segmentation masks on to an image.

    Parameters:
        image: the input image, given as a file path, a numpy array or a PIL image
        masks: a dictionary with a "labels" list and a "masks" list; a mask that is
            a string is opened as an image file, anything else is used as an array
        alpha: the opacity of the mask color, scaled to the 0-255 alpha channel

    Returns:
        The image with the masks blended in, in RGB mode
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, (str, Path)):
        image = Image.open(image)

    palette = {
        name: COLORS[idx % len(COLORS)] for idx, name in enumerate(masks["labels"])
    }
    if "masks" not in masks:
        return image.convert("RGB")

    opacity = (255 * alpha,)
    for name, layer in zip(masks["labels"], masks["masks"]):
        if isinstance(layer, str):
            layer = np.array(Image.open(layer))
        img_width, img_height = image.size
        # transparent RGBA canvas; only pixels where the mask is positive get the
        # label color and the requested opacity
        rgba = np.zeros((img_height, img_width, 4))
        rgba[layer > 0, :] = palette[name] + opacity
        overlay = Image.fromarray(rgba.astype(np.uint8))
        image = Image.alpha_composite(image.convert("RGBA"), overlay)
    return image.convert("RGB")
15 changes: 14 additions & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
from .tools import (
CLIP,
TOOLS,
BboxArea,
BboxIoU,
Counter,
Crop,
ExtractFrames,
GroundingDINO,
GroundingSAM,
SegArea,
SegIoU,
Tool,
)
Loading
Loading