Skip to content

Commit

Permalink
Add Count tools (#56)
Browse files Browse the repository at this point in the history
* Adding counting tools to vision agent

* fixed heatmap overlay and addressed PR comments

* adding the counting tool to take both absolute and normalized coordinates, refactoring code, adding llm generate counter tool

* fix linting
  • Loading branch information
shankar-vision-eng authored Apr 22, 2024
1 parent a860ebe commit 5f11aea
Show file tree
Hide file tree
Showing 7 changed files with 273 additions and 49 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
Many current vision problems can easily take hours or days to solve: you need to find the
right model, figure out how to use it, possibly write programming logic around it to
right model, figure out how to use it, possibly write programming logic around it to
accomplish the task you want or even more expensive, train your own model. Vision Agent
aims to provide an in-seconds experience by allowing users to describe their problem in
text and utilizing agent frameworks to solve the task for them. Check out our discord
Expand Down Expand Up @@ -108,6 +108,9 @@ you. For example:
| BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
| SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt. |


It also has a basic set of calculate tools such as add, subtract, multiply and divide.
Expand Down
43 changes: 30 additions & 13 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from PIL import Image
from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_masks
from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -336,7 +336,9 @@ def _handle_viz_tools(

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
if not isinstance(call_result, dict) or (
"bboxes" not in call_result and "masks" not in call_result
):
return image_to_data

# if the call was successful, then we can add the image data
Expand All @@ -349,11 +351,12 @@ def _handle_viz_tools(
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])
image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
image_to_data[image]["labels"].extend(call_result.get("labels", []))
image_to_data[image]["scores"].extend(call_result.get("scores", []))
image_to_data[image]["masks"].extend(call_result.get("masks", []))
if "mask_shape" in call_result:
image_to_data[image]["mask_shape"] = call_result["mask_shape"]

return image_to_data

Expand All @@ -367,6 +370,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
"grounding_dino_",
"extract_frames_",
"dinov_",
"zero_shot_counting_",
"visual_prompt_counting_",
]:
continue

Expand All @@ -379,8 +384,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
for image_str in image_to_data:
image_path = Path(image_str)
image_data = image_to_data[image_str]
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
if "_counting_" in tool_result["tool_name"]:
image = overlay_heat_map(image_path, image_data)
else:
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
image.save(f.name)
visualized_images.append(f.name)
Expand Down Expand Up @@ -484,11 +492,21 @@ def chat_with_workflow(
if image:
question += f" Image name: {image}"
if reference_data:
if not ("image" in reference_data and "mask" in reference_data):
if not (
"image" in reference_data
and ("mask" in reference_data or "bbox" in reference_data)
):
raise ValueError(
f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
)
question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
visual_prompt_data = (
f"Reference mask: {reference_data['mask']}"
if "mask" in reference_data
else f"Reference bbox: {reference_data['bbox']}"
)
question += (
f" Reference image: {reference_data['image']}, {visual_prompt_data}"
)

reflections = ""
final_answer = ""
Expand Down Expand Up @@ -531,7 +549,6 @@ def chat_with_workflow(
final_answer = answer_summarize(
self.answer_model, question, answers, reflections
)

visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
if len(visualized_output) > 0:
Expand Down
101 changes: 96 additions & 5 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from importlib import resources
from io import BytesIO
from pathlib import Path
from typing import Dict, Tuple, Union
from typing import Dict, Tuple, Union, List

import numpy as np
from PIL import Image, ImageDraw, ImageFont
Expand Down Expand Up @@ -34,6 +34,35 @@
]


def normalize_bbox(
    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
) -> List[float]:
    r"""Divide the bounding box coordinates by the image dimensions.

    Parameters:
        bbox: the box as [x1, y1, x2, y2]
        image_size: the image dimensions; index 0 scales the y values and
            index 1 scales the x values
    Returns:
        The scaled [x1, y1, x2, y2], each rounded to 2 decimal places
    """
    left, top, right, bottom = bbox
    x_extent = image_size[1]
    y_extent = image_size[0]
    return [
        round(left / x_extent, 2),
        round(top / y_extent, 2),
        round(right / x_extent, 2),
        round(bottom / y_extent, 2),
    ]


def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
    r"""Decode a run-length encoded mask. Returns numpy array, 1 - mask, 0 - background.

    Parameters:
        mask_rle: Run-length as a whitespace separated string formatted
            "start length start length ...", where each start is a 1-indexed
            position in the flattened mask
        shape: The (height, width) of array to return
    Returns:
        A uint8 array of the requested shape
    Raises:
        ValueError: if mask_rle does not contain an even number of tokens
    """
    tokens = mask_rle.split()
    # Without this check an unpaired token is either broadcast by numpy (a
    # single length silently applied to every start) or fails with an opaque
    # shape error, so reject it up front.
    if len(tokens) % 2 != 0:
        raise ValueError(
            "Run-length encoding must contain (start, length) pairs, "
            f"but got {len(tokens)} values."
        )

    # Starts are 1-indexed in the encoding; shift them to 0-indexed offsets.
    starts = np.asarray(tokens[0::2], dtype=int) - 1
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths

    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)


def b64_to_pil(b64_str: str) -> ImageType:
r"""Convert a base64 string to a PIL Image.
Expand Down Expand Up @@ -86,6 +115,26 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
return base64.b64encode(arr_bytes).decode("utf-8")


def denormalize_bbox(
    bbox: List[Union[int, float]], image_size: Tuple[int, ...]
) -> List[float]:
    r"""Convert a normalized bounding box into absolute coordinates.

    Parameters:
        bbox: the box as [x1, y1, x2, y2]
        image_size: the image dimensions; index 0 scales the y values and
            index 1 scales the x values
    Returns:
        The scaled and rounded [x1, y1, x2, y2] when every value of bbox lies
        in [0, 1]; otherwise bbox itself, unchanged
    Raises:
        ValueError: if bbox does not have exactly 4 values
    """
    if len(bbox) != 4:
        raise ValueError("Bounding box must be of length 4.")

    coords = np.array(bbox)
    in_unit_range = (coords >= 0) & (coords <= 1)
    if not np.all(in_unit_range):
        # Any value outside [0, 1] means the box is taken as already absolute.
        return bbox

    left, top, right, bottom = bbox
    x_extent = image_size[1]
    y_extent = image_size[0]
    return [
        round(left * x_extent),
        round(top * y_extent),
        round(right * x_extent),
        round(bottom * y_extent),
    ]


def overlay_bboxes(
image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict
) -> ImageType:
Expand All @@ -103,6 +152,9 @@ def overlay_bboxes(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "bboxes" not in bboxes:
return image.convert("RGB")

color = {
label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
}
Expand All @@ -114,8 +166,6 @@ def overlay_bboxes(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
)
if "bboxes" not in bboxes:
return image.convert("RGB")

for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
box = [
Expand Down Expand Up @@ -150,11 +200,15 @@ def overlay_masks(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "masks" not in masks:
return image.convert("RGB")

if "labels" not in masks:
masks["labels"] = [""] * len(masks["masks"])

color = {
label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
}
if "masks" not in masks:
return image.convert("RGB")

for label, mask in zip(masks["labels"], masks["masks"]):
if isinstance(mask, str):
Expand All @@ -164,3 +218,40 @@ def overlay_masks(
mask_img = Image.fromarray(np_mask.astype(np.uint8))
image = Image.alpha_composite(image.convert("RGBA"), mask_img)
return image.convert("RGB")


def overlay_heat_map(
    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
) -> ImageType:
    r"""Plots heat map on to an image.

    Parameters:
        image: the input image
        masks: the heatmap to overlay; only the first entry of masks["masks"]
            is used
        alpha: the opacity of the red overlay, scaled to the 0-255 alpha channel
    Returns:
        The image with the heatmap overlaid, or the image converted to RGB
        when there is no heatmap to draw
    Raises:
        TypeError: if the heatmap is not a base64 string, numpy array or PIL image
    """
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # A missing key and an empty list both mean there is nothing to overlay;
    # indexing [0] on an empty list would otherwise raise IndexError.
    heat_maps = masks.get("masks")
    if heat_maps is None or len(heat_maps) == 0:
        return image.convert("RGB")

    # Only one heat map per image, so no need to loop through masks
    mask = heat_maps[0]
    if isinstance(mask, str):
        mask = b64_to_pil(mask)
    elif isinstance(mask, np.ndarray):
        # NOTE(review): ImageDraw.bitmap needs a mode usable as a mask;
        # confirm the dtype of array heat maps produced by callers.
        mask = Image.fromarray(mask)
    elif not isinstance(mask, Image.Image):
        raise TypeError(
            f"Heat map must be a base64 string, numpy array or PIL image, but got {type(mask)}"
        )

    # Drop the colour from the base image before compositing (presumably so the
    # red overlay stands out).
    image = image.convert("L")

    overlay = Image.new("RGBA", mask.size)
    odraw = ImageDraw.Draw(overlay)
    odraw.bitmap(
        (0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
    )  # fill=(R, G, B, Alpha)
    # The overlay is built at the mask's size, so resize it to the image first.
    combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))

    return combined.convert("RGB")
4 changes: 4 additions & 0 deletions vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ZeroShotCounting,
)


Expand Down Expand Up @@ -127,6 +128,9 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})

def generate_zero_shot_counter(self, question: str) -> Callable:
    """Build a callable that runs ZeroShotCounting on a single image.

    Parameters:
        question: not used when building the counter
    Returns:
        A one-argument callable that creates a ZeroShotCounting tool on each
        call and invokes it with the argument passed as ``image``
    """

    def count_objects(x):
        return ZeroShotCounting()(image=x)

    return count_objects


class AzureOpenAILLM(OpenAILLM):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SYSTEM_PROMPT,
GroundingDINO,
GroundingSAM,
ZeroShotCounting,
)

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -272,6 +273,9 @@ def generate_segmentor(self, question: str) -> Callable:

return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})

def generate_zero_shot_counter(self, question: str) -> Callable:
    """Build a callable that runs ZeroShotCounting on a single image.

    Parameters:
        question: not used when building the counter
    Returns:
        A one-argument callable that creates a ZeroShotCounting tool on each
        call and invokes it with the argument passed as ``image``
    """

    def count_objects(x):
        return ZeroShotCounting()(image=x)

    return count_objects


class AzureOpenAILMM(OpenAILMM):
def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
GroundingDINO,
GroundingSAM,
ImageCaption,
ZeroShotCounting,
VisualPromptCounting,
SegArea,
SegIoU,
Tool,
Expand Down
Loading

0 comments on commit 5f11aea

Please sign in to comment.