Skip to content

Commit

Permalink
added code to plot tracking labels which are dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Aug 27, 2024
1 parent 46343e0 commit 7165bf1
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1418,6 +1418,7 @@ def overlay_segmentation_masks(
medias: Union[np.ndarray, List[np.ndarray]],
masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
draw_label: bool = True,
secondary_label_key: str = "tracking_label",
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Expand All @@ -1426,7 +1427,10 @@ def overlay_segmentation_masks(
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
the masks on.
masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries containing the masks.
dictionaries containing the masks, labels and scores.
draw_label (bool, optional): If True, the labels will be displayed on the image.
secondary_label_key (str, optional): The key to use for the secondary
tracking label which is needed in videos to display tracking information.
Returns:
np.ndarray: The image with the masks displayed.
Expand Down Expand Up @@ -1471,23 +1475,25 @@ def overlay_segmentation_masks(
for elt in masks_int[i]:
mask = elt["mask"]
label = elt["label"]
tracking_lbl = elt.get(secondary_label_key, None)
np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

if draw_label:
draw = ImageDraw.Draw(pil_image)
text_box = draw.textbbox((0, 0), text=label, font=font)
text = tracking_lbl if tracking_lbl else label
text_box = draw.textbbox((0, 0), text=text, font=font)
x, y = _get_text_coords_from_mask(
mask,
v_gap=(text_box[3] - text_box[1]) + 10,
h_gap=(text_box[2] - text_box[0]) // 2,
)
if x != 0 and y != 0:
text_box = draw.textbbox((x, y), text=label, font=font)
text_box = draw.textbbox((x, y), text=text, font=font)
draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
draw.text((x, y), label, fill="black", font=font)
draw.text((x, y), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out

Expand Down

0 comments on commit 7165bf1

Please sign in to comment.