Skip to content

Commit

Permalink
merge overlay count into overlay bbox
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 4, 2024
1 parent 0a087ce commit 9a1394e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 56 deletions.
1 change: 0 additions & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
loca_zero_shot_counting,
ocr,
overlay_bounding_boxes,
overlay_counting_results,
overlay_heat_map,
overlay_segmentation_masks,
owl_v2_image,
Expand Down
95 changes: 40 additions & 55 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import cv2
import numpy as np
import requests
from PIL import Image, ImageDraw, ImageEnhance, ImageFont
from PIL import Image, ImageDraw, ImageFont
from pillow_heif import register_heif_opener # type: ignore
from pytube import YouTube # type: ignore

Expand Down Expand Up @@ -1917,30 +1917,36 @@ def overlay_bounding_boxes(
bboxes = bbox_int[i]
bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype(
str(
resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")
),
fontsize,
)

for elt in bboxes:
label = elt["label"]
box = elt["bbox"]
scores = elt["score"]

# denormalize the box if it is normalized
box = denormalize_bbox(box, (height, width))
draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle(
(box[0], box[1], text_box[2], text_box[3]), fill=color[label]
if len(bboxes) > 20:
pil_image = _plot_counting(pil_image, bboxes, color)
else:
width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype(
str(
resources.files("vision_agent.fonts").joinpath(
"default_font_ch_en.ttf"
)
),
fontsize,
)
draw.text((box[0], box[1]), text, fill="black", font=font)

for elt in bboxes:
label = elt["label"]
box = elt["bbox"]
scores = elt["score"]

# denormalize the box if it is normalized
box = denormalize_bbox(box, (height, width))
draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle(
(box[0], box[1], text_box[2], text_box[3]), fill=color[label]
)
draw.text((box[0], box[1]), text, fill="black", font=font)

frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out

Expand Down Expand Up @@ -2099,39 +2105,19 @@ def overlay_heat_map(
return np.array(combined)


def overlay_counting_results(
image: np.ndarray, instances: List[Dict[str, Any]]
) -> np.ndarray:
"""'overlay_counting_results' is a utility function that displays counting results on
an image.
Parameters:
image (np.ndarray): The image to display the bounding boxes on.
instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding
box information of each instance
Returns:
np.ndarray: The image with the instance_id dislpayed
Example
-------
>>> image_with_bboxes = overlay_counting_results(
image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")
color = (158, 218, 229)

width, height = pil_image.size
def _plot_counting(
image: Image.Image,
bboxes: List[Dict[str, Any]],
colors: Dict[str, Tuple[int, int, int]],
) -> Image.Image:
width, height = image.size
fontsize = max(10, int(min(width, height) / 80))
pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5)
draw = ImageDraw.Draw(pil_image)
draw = ImageDraw.Draw(image)
font = ImageFont.truetype(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
)

for i, elt in enumerate(instances, 1):
for i, elt in enumerate(bboxes, 1):
label = f"{i}"
box = elt["bbox"]

Expand All @@ -2153,7 +2139,7 @@ def overlay_counting_results(
text_y1 = cy + text_height / 2

# Draw the rectangle encapsulating the text
draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color)
draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=colors[elt["label"]])

# Draw the text at the center of the bounding box
draw.text(
Expand All @@ -2164,7 +2150,7 @@ def overlay_counting_results(
anchor="lt",
)

return np.array(pil_image)
return image


FUNCTION_TOOLS = [
Expand Down Expand Up @@ -2197,7 +2183,6 @@ def overlay_counting_results(
overlay_bounding_boxes,
overlay_segmentation_masks,
overlay_heat_map,
overlay_counting_results,
]

TOOLS = FUNCTION_TOOLS + UTIL_TOOLS
Expand Down

0 comments on commit 9a1394e

Please sign in to comment.