diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 9a0ba954..bbc7a833 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1387,6 +1387,7 @@ def overlay_segmentation_masks( medias: Union[np.ndarray, List[np.ndarray]], masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]], draw_label: bool = True, + secondary_label_key: str = "new_label", ) -> Union[np.ndarray, List[np.ndarray]]: """'overlay_segmentation_masks' is a utility function that displays segmentation masks. @@ -1425,6 +1426,7 @@ def overlay_segmentation_masks( for mask_i in masks_int: for mask_j in mask_i: labels.add(mask_j["label"]) + color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)} width, height = Image.fromarray(medias_int[0]).size @@ -1440,6 +1442,7 @@ def overlay_segmentation_masks( for elt in masks_int[i]: mask = elt["mask"] label = elt["label"] + secondary_lbl = elt.get(secondary_label_key, None) np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4)) np_mask[mask > 0, :] = color[label] + (255 * 0.5,) mask_img = Image.fromarray(np_mask.astype(np.uint8)) @@ -1447,16 +1450,17 @@ def overlay_segmentation_masks( if draw_label: draw = ImageDraw.Draw(pil_image) - text_box = draw.textbbox((0, 0), text=label, font=font) + text = secondary_lbl if secondary_lbl else label + text_box = draw.textbbox((0, 0), text=text, font=font) x, y = _get_text_coords_from_mask( mask, v_gap=(text_box[3] - text_box[1]) + 10, h_gap=(text_box[2] - text_box[0]) // 2, ) if x != 0 and y != 0: - text_box = draw.textbbox((x, y), text=label, font=font) + text_box = draw.textbbox((x, y), text=text, font=font) draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label]) - draw.text((x, y), label, fill="black", font=font) + draw.text((x, y), text, fill="black", font=font) frame_out.append(np.array(pil_image)) return frame_out[0] if len(frame_out) == 1 else frame_out