diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index fca3819c..35e8adb3 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1754,14 +1754,17 @@ def _save_video_to_result(video_uri: str) -> None: def overlay_bounding_boxes( - image: np.ndarray, bboxes: List[Dict[str, Any]] + medias: Union[np.ndarray, List[np.ndarray]], + bboxes: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], ) -> np.ndarray: """'overlay_bounding_boxes' is a utility function that displays bounding boxes on an image. Parameters: - image (np.ndarray): The image to display the bounding boxes on. - bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding + medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display the + bounding boxes on. + bboxes (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of + dictionaries or a list of lists of dictionaries containing the bounding boxes. Returns: @@ -1773,41 +1776,54 @@ def overlay_bounding_boxes( image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}], ) """ - pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB") - if len(set([box["label"] for box in bboxes])) > len(COLORS): + medias_int: List[np.ndarray] = ( + [medias] if isinstance(medias, np.ndarray) else medias + ) + bbox_int = [bboxes] if isinstance(bboxes[0], dict) else bboxes + bbox_int = cast(List[List[Dict[str, Any]]], bbox_int) + labels = set([bb["label"] for b in bbox_int for bb in b]) + + if len(labels) > len(COLORS): _LOGGER.warning( "Number of unique labels exceeds the number of available colors. Some labels may have the same color." 
) - color = { - label: COLORS[i % len(COLORS)] - for i, label in enumerate(set([box["label"] for box in bboxes])) - } - bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True) + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)} - width, height = pil_image.size - fontsize = max(12, int(min(width, height) / 40)) - draw = ImageDraw.Draw(pil_image) - font = ImageFont.truetype( - str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), - fontsize, - ) + frame_out = [] + for i, frame in enumerate(medias_int): + pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGB") - for elt in bboxes: - label = elt["label"] - box = elt["bbox"] - scores = elt["score"] + bboxes = bbox_int[i] + bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True) - # denormalize the box if it is normalized - box = denormalize_bbox(box, (height, width)) + width, height = pil_image.size + fontsize = max(12, int(min(width, height) / 40)) + draw = ImageDraw.Draw(pil_image) + font = ImageFont.truetype( + str( + resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf") + ), + fontsize, + ) - draw.rectangle(box, outline=color[label], width=4) - text = f"{label}: {scores:.2f}" - text_box = draw.textbbox((box[0], box[1]), text=text, font=font) - draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label]) - draw.text((box[0], box[1]), text, fill="black", font=font) - return np.array(pil_image) + for elt in bboxes: + label = elt["label"] + box = elt["bbox"] + scores = elt["score"] + + # denormalize the box if it is normalized + box = denormalize_bbox(box, (height, width)) + draw.rectangle(box, outline=color[label], width=4) + text = f"{label}: {scores:.2f}" + text_box = draw.textbbox((box[0], box[1]), text=text, font=font) + draw.rectangle( + (box[0], box[1], text_box[2], text_box[3]), fill=color[label] + ) + draw.text((box[0], box[1]), text, fill="black", font=font) + frame_out.append(np.array(pil_image)) + 
return frame_out[0] if len(frame_out) == 1 else frame_out def _get_text_coords_from_mask( @@ -1847,7 +1863,8 @@ def overlay_segmentation_masks( medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display the masks on. masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of - dictionaries containing the masks, labels and scores. + dictionaries or a list of lists of dictionaries containing the masks, labels + and scores. draw_label (bool, optional): If True, the labels will be displayed on the image. secondary_label_key (str, optional): The key to use for the secondary tracking label which is needed in videos to display tracking information.