From ecd8a3796195a4d95ad323c41a8605ae641aa595 Mon Sep 17 00:00:00 2001 From: Camilo Iral Date: Fri, 13 Sep 2024 13:47:35 -0500 Subject: [PATCH] fix overlay_segmentation_masks util --- vision_agent/tools/tools.py | 86 +++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 46 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 63927f01..869b3b9e 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -146,9 +146,7 @@ def grounding_dino( def owl_v2_image( - prompt: str, - image: np.ndarray, - box_threshold: float = 0.10, + prompt: str, image: np.ndarray, box_threshold: float = 0.10, ) -> List[Dict[str, Any]]: """'owl_v2_image' is a tool that can detect and count multiple objects given a text prompt such as category names or referring expressions on images. The categories in @@ -201,9 +199,7 @@ def owl_v2_image( def owl_v2_video( - prompt: str, - frames: List[np.ndarray], - box_threshold: float = 0.10, + prompt: str, frames: List[np.ndarray], box_threshold: float = 0.10, ) -> List[List[Dict[str, Any]]]: """'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple objects per frame given a text prompt sucha s a category name or referring @@ -581,9 +577,7 @@ def loca_visual_prompt_counting( def countgd_counting( - prompt: str, - image: np.ndarray, - box_threshold: float = 0.23, + prompt: str, image: np.ndarray, box_threshold: float = 0.23, ) -> List[Dict[str, Any]]: """'countgd_counting' is a tool that can precisely count multiple instances of an object given a text prompt. It returns a list of bounding boxes with normalized @@ -634,9 +628,7 @@ def countgd_counting( def countgd_example_based_counting( - visual_prompts: List[List[float]], - image: np.ndarray, - box_threshold: float = 0.23, + visual_prompts: List[List[float]], image: np.ndarray, box_threshold: float = 0.23, ) -> List[Dict[str, Any]]: """'countgd_example_based_counting' is a tool that can precisely count multiple instances of an object given few visual example prompts. It returns a list of bounding @@ -1491,7 +1483,7 @@ def closest_box_distance( horizontal_distance = np.max([0, x21 - x12, x11 - x22]) vertical_distance = np.max([0, y21 - y12, y11 - y22]) - return cast(float, np.sqrt(horizontal_distance**2 + vertical_distance**2)) + return cast(float, np.sqrt(horizontal_distance ** 2 + vertical_distance ** 2)) # Utility and visualization functions @@ -1753,6 +1745,7 @@ def overlay_segmentation_masks( masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], draw_label: bool = True, secondary_label_key: str = "tracking_label", + fontsize: Optional[int] = None, ) -> Union[np.ndarray, List[np.ndarray]]: """'overlay_segmentation_masks' is a utility function that displays segmentation masks. @@ -1761,10 +1754,11 @@ def overlay_segmentation_masks( medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display the masks on. masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of - dictionaries containing the masks, labels and scores. + dictionaries containing the masks and labels. draw_label (bool, optional): If True, the labels will be displayed on the image. secondary_label_key (str, optional): The key to use for the secondary tracking label which is needed in videos to display tracking information. + fontsize (Optional[int], optional): The font size to use for the labels in case of needing a custom size. Returns: np.ndarray: The image with the masks displayed. @@ -1774,7 +1768,6 @@ def overlay_segmentation_masks( >>> image_with_masks = overlay_segmentation_masks( image, [{ - 'score': 0.99, 'label': 'dinosaur', 'mask': array([[0, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 0, 0, 0], @@ -1785,19 +1778,21 @@ def overlay_segmentation_masks( ) """ medias_int: List[np.ndarray] = ( - [medias] if isinstance(medias, np.ndarray) else medias + [media for media in medias] if isinstance(medias, np.ndarray) else medias ) masks_int = [masks] if isinstance(masks[0], dict) else masks masks_int = cast(List[List[Dict[str, Any]]], masks_int) labels = set() for mask_i in masks_int: - for mask_j in mask_i: - labels.add(mask_j["label"]) + if mask_i is not None: + for mask_j in mask_i: + labels.add(mask_j["label"]) color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)} width, height = Image.fromarray(medias_int[0]).size - fontsize = max(12, int(min(width, height) / 40)) + if fontsize is None: + fontsize = max(12, int(min(width, height) / 40)) font = ImageFont.truetype( str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), fontsize, @@ -1806,28 +1801,31 @@ def overlay_segmentation_masks( frame_out = [] for i, frame in enumerate(medias_int): pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA") - for elt in masks_int[i]: - mask = elt["mask"] - label = elt["label"] - tracking_lbl = elt.get(secondary_label_key, None) - np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4)) - np_mask[mask > 0, :] = color[label] + (255 * 0.5,) - mask_img = Image.fromarray(np_mask.astype(np.uint8)) - pil_image = Image.alpha_composite(pil_image, mask_img) - - if draw_label: - draw = ImageDraw.Draw(pil_image) - text = tracking_lbl if tracking_lbl else label - text_box = draw.textbbox((0, 0), text=text, font=font) - x, y = _get_text_coords_from_mask( - mask, - v_gap=(text_box[3] - text_box[1]) + 10, - h_gap=(text_box[2] - text_box[0]) // 2, - ) - if x != 0 and y != 0: - text_box = draw.textbbox((x, y), text=text, font=font) - draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label]) - draw.text((x, y), text, fill="black", font=font) + if masks_int[i] is not None: + for elt in masks_int[i]: + mask = elt["mask"] + label = elt["label"] + tracking_lbl = elt.get(secondary_label_key, None) + np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4)) + np_mask[mask > 0, :] = color[label] + (255 * 0.5,) + mask_img = Image.fromarray(np_mask.astype(np.uint8)) + pil_image = Image.alpha_composite(pil_image, mask_img) + + if draw_label and label is not None: + draw = ImageDraw.Draw(pil_image) + text = tracking_lbl if tracking_lbl else label + text_box = draw.textbbox((0, 0), text=text, font=font) + x, y = get_text_coords_from_mask( + mask, + v_gap=(text_box[3] - text_box[1]) + 10, + h_gap=(text_box[2] - text_box[0]) // 2, + ) + if x != 0 and y != 0: + text_box = draw.textbbox((x, y), text=text, font=font) + draw.rectangle( + (x, y, text_box[2], text_box[3]), fill=color[label] + ) + draw.text((x, y), text, fill="black", font=font) frame_out.append(np.array(pil_image)) return frame_out[0] if len(frame_out) == 1 else frame_out @@ -1935,11 +1933,7 @@ def overlay_counting_results( # Draw the text at the center of the bounding box draw.text( - (text_x0, text_y0), - label, - fill="black", - font=font, - anchor="lt", + (text_x0, text_y0), label, fill="black", font=font, anchor="lt", ) return np.array(pil_image)