From 7165bf141a02dfe9157ff44e1522eb63627c0b0f Mon Sep 17 00:00:00 2001 From: shankar-landing-ai Date: Tue, 27 Aug 2024 12:31:49 -0700 Subject: [PATCH] added code to plot tracking labels which are dynamic --- vision_agent/tools/tools.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 250d6d78..4e1a0f40 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1418,6 +1418,7 @@ def overlay_segmentation_masks( medias: Union[np.ndarray, List[np.ndarray]], masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], draw_label: bool = True, + secondary_label_key: str = "tracking_label", ) -> Union[np.ndarray, List[np.ndarray]]: """'overlay_segmentation_masks' is a utility function that displays segmentation masks. @@ -1426,7 +1427,10 @@ def overlay_segmentation_masks( medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display the masks on. masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of - dictionaries containing the masks. + dictionaries containing the masks, labels and scores. + draw_label (bool, optional): If True, the labels will be displayed on the image. + secondary_label_key (str, optional): The key to use for the secondary + tracking label which is needed in videos to display tracking information. Returns: np.ndarray: The image with the masks displayed. @@ -1471,6 +1475,7 @@ def overlay_segmentation_masks( for elt in masks_int[i]: mask = elt["mask"] label = elt["label"] + tracking_lbl = elt.get(secondary_label_key, None) np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4)) np_mask[mask > 0, :] = color[label] + (255 * 0.5,) mask_img = Image.fromarray(np_mask.astype(np.uint8)) @@ -1478,16 +1483,17 @@ def overlay_segmentation_masks( if draw_label: draw = ImageDraw.Draw(pil_image) - text_box = draw.textbbox((0, 0), text=label, font=font) + text = tracking_lbl if tracking_lbl else label + text_box = draw.textbbox((0, 0), text=text, font=font) x, y = _get_text_coords_from_mask( mask, v_gap=(text_box[3] - text_box[1]) + 10, h_gap=(text_box[2] - text_box[0]) // 2, ) if x != 0 and y != 0: - text_box = draw.textbbox((x, y), text=label, font=font) + text_box = draw.textbbox((x, y), text=text, font=font) draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label]) - draw.text((x, y), label, fill="black", font=font) + draw.text((x, y), text, fill="black", font=font) frame_out.append(np.array(pil_image)) return frame_out[0] if len(frame_out) == 1 else frame_out