Skip to content

Commit

Permalink
adding secondary label to the seg masks so that color does not change…
Browse files Browse the repository at this point in the history
… when the label changes
  • Loading branch information
shankar-vision-eng committed Aug 20, 2024
1 parent aad2a00 commit 4131742
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ def overlay_segmentation_masks(
medias: Union[np.ndarray, List[np.ndarray]],
masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
draw_label: bool = True,
secondary_label_key: str = "new_label",
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Expand Down Expand Up @@ -1425,6 +1426,7 @@ def overlay_segmentation_masks(
for mask_i in masks_int:
for mask_j in mask_i:
labels.add(mask_j["label"])

color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

width, height = Image.fromarray(medias_int[0]).size
Expand All @@ -1440,23 +1442,25 @@ def overlay_segmentation_masks(
for elt in masks_int[i]:
mask = elt["mask"]
label = elt["label"]
secondary_lbl = elt.get(secondary_label_key, None)
np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

if draw_label:
draw = ImageDraw.Draw(pil_image)
text_box = draw.textbbox((0, 0), text=label, font=font)
text = secondary_lbl if secondary_lbl else label
text_box = draw.textbbox((0, 0), text=text, font=font)
x, y = _get_text_coords_from_mask(
mask,
v_gap=(text_box[3] - text_box[1]) + 10,
h_gap=(text_box[2] - text_box[0]) // 2,
)
if x != 0 and y != 0:
text_box = draw.textbbox((x, y), text=label, font=font)
text_box = draw.textbbox((x, y), text=text, font=font)
draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
draw.text((x, y), label, fill="black", font=font)
draw.text((x, y), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out

Expand Down

0 comments on commit 4131742

Please sign in to comment.