diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index e5b7c334..da74f677 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -45,7 +45,6 @@ loca_zero_shot_counting, ocr, overlay_bounding_boxes, - overlay_counting_results, overlay_heat_map, overlay_segmentation_masks, owl_v2_image, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index f83132a5..b2b8a985 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -13,7 +13,7 @@ import cv2 import numpy as np import requests -from PIL import Image, ImageDraw, ImageEnhance, ImageFont +from PIL import Image, ImageDraw, ImageFont from pillow_heif import register_heif_opener # type: ignore from pytube import YouTube # type: ignore @@ -1917,30 +1917,36 @@ def overlay_bounding_boxes( bboxes = bbox_int[i] bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True) - width, height = pil_image.size - fontsize = max(12, int(min(width, height) / 40)) - draw = ImageDraw.Draw(pil_image) - font = ImageFont.truetype( - str( - resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf") - ), - fontsize, - ) - - for elt in bboxes: - label = elt["label"] - box = elt["bbox"] - scores = elt["score"] - - # denormalize the box if it is normalized - box = denormalize_bbox(box, (height, width)) - draw.rectangle(box, outline=color[label], width=4) - text = f"{label}: {scores:.2f}" - text_box = draw.textbbox((box[0], box[1]), text=text, font=font) - draw.rectangle( - (box[0], box[1], text_box[2], text_box[3]), fill=color[label] + if len(bboxes) > 20: + pil_image = _plot_counting(pil_image, bboxes, color) + else: + width, height = pil_image.size + fontsize = max(12, int(min(width, height) / 40)) + draw = ImageDraw.Draw(pil_image) + font = ImageFont.truetype( + str( + resources.files("vision_agent.fonts").joinpath( + "default_font_ch_en.ttf" + ) + ), + fontsize, ) - draw.text((box[0], box[1]), text, fill="black", font=font) + + for elt in bboxes: + label = elt["label"] + box = elt["bbox"] + scores = elt["score"] + + # denormalize the box if it is normalized + box = denormalize_bbox(box, (height, width)) + draw.rectangle(box, outline=color[label], width=4) + text = f"{label}: {scores:.2f}" + text_box = draw.textbbox((box[0], box[1]), text=text, font=font) + draw.rectangle( + (box[0], box[1], text_box[2], text_box[3]), fill=color[label] + ) + draw.text((box[0], box[1]), text, fill="black", font=font) + frame_out.append(np.array(pil_image)) return frame_out[0] if len(frame_out) == 1 else frame_out @@ -2099,39 +2105,19 @@ def overlay_heat_map( return np.array(combined) -def overlay_counting_results( - image: np.ndarray, instances: List[Dict[str, Any]] -) -> np.ndarray: - """'overlay_counting_results' is a utility function that displays counting results on - an image. - - Parameters: - image (np.ndarray): The image to display the bounding boxes on. - instances (List[Dict[str, Any]]): A list of dictionaries containing the bounding - box information of each instance - - Returns: - np.ndarray: The image with the instance_id dislpayed - - Example - ------- - >>> image_with_bboxes = overlay_counting_results( - image, [{'score': 0.99, 'label': 'object', 'bbox': [0.1, 0.11, 0.35, 0.4]}], - ) - """ - pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") - color = (158, 218, 229) - - width, height = pil_image.size +def _plot_counting( + image: Image.Image, + bboxes: List[Dict[str, Any]], + colors: Dict[str, Tuple[int, int, int]], +) -> Image.Image: + width, height = image.size fontsize = max(10, int(min(width, height) / 80)) - pil_image = ImageEnhance.Brightness(pil_image).enhance(0.5) - draw = ImageDraw.Draw(pil_image) + draw = ImageDraw.Draw(image) font = ImageFont.truetype( str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), fontsize, ) - - for i, elt in enumerate(instances, 1): + for i, elt in enumerate(bboxes, 1): label = f"{i}" box = elt["bbox"] @@ -2153,7 +2139,7 @@ def overlay_counting_results( text_y1 = cy + text_height / 2 # Draw the rectangle encapsulating the text - draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=color) + draw.rectangle((text_x0, text_y0, text_x1, text_y1), fill=colors[elt["label"]]) # Draw the text at the center of the bounding box draw.text( @@ -2164,7 +2150,7 @@ def overlay_counting_results( anchor="lt", ) - return np.array(pil_image) + return image FUNCTION_TOOLS = [ @@ -2197,7 +2183,6 @@ def overlay_counting_results( overlay_bounding_boxes, overlay_segmentation_masks, overlay_heat_map, - overlay_counting_results, ] TOOLS = FUNCTION_TOOLS + UTIL_TOOLS