Skip to content

Commit

Permalink
fix overlay_segmentation_masks util
Browse files Browse the repository at this point in the history
  • Loading branch information
Camilo Iral committed Sep 13, 2024
1 parent b66a8c9 commit ecd8a37
Showing 1 changed file with 40 additions and 46 deletions.
86 changes: 40 additions & 46 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def grounding_dino(


def owl_v2_image(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.10,
prompt: str, image: np.ndarray, box_threshold: float = 0.10,
) -> List[Dict[str, Any]]:
"""'owl_v2_image' is a tool that can detect and count multiple objects given a text
prompt such as category names or referring expressions on images. The categories in
Expand Down Expand Up @@ -201,9 +199,7 @@ def owl_v2_image(


def owl_v2_video(
prompt: str,
frames: List[np.ndarray],
box_threshold: float = 0.10,
prompt: str, frames: List[np.ndarray], box_threshold: float = 0.10,
) -> List[List[Dict[str, Any]]]:
"""'owl_v2_video' will run owl_v2 on each frame of a video. It can detect multiple
objects per frame given a text prompt sucha s a category name or referring
Expand Down Expand Up @@ -581,9 +577,7 @@ def loca_visual_prompt_counting(


def countgd_counting(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.23,
prompt: str, image: np.ndarray, box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
"""'countgd_counting' is a tool that can precisely count multiple instances of an
object given a text prompt. It returns a list of bounding boxes with normalized
Expand Down Expand Up @@ -634,9 +628,7 @@ def countgd_counting(


def countgd_example_based_counting(
visual_prompts: List[List[float]],
image: np.ndarray,
box_threshold: float = 0.23,
visual_prompts: List[List[float]], image: np.ndarray, box_threshold: float = 0.23,
) -> List[Dict[str, Any]]:
"""'countgd_example_based_counting' is a tool that can precisely count multiple
instances of an object given few visual example prompts. It returns a list of bounding
Expand Down Expand Up @@ -1491,7 +1483,7 @@ def closest_box_distance(

horizontal_distance = np.max([0, x21 - x12, x11 - x22])
vertical_distance = np.max([0, y21 - y12, y11 - y22])
return cast(float, np.sqrt(horizontal_distance**2 + vertical_distance**2))
return cast(float, np.sqrt(horizontal_distance ** 2 + vertical_distance ** 2))


# Utility and visualization functions
Expand Down Expand Up @@ -1753,6 +1745,7 @@ def overlay_segmentation_masks(
masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
draw_label: bool = True,
secondary_label_key: str = "tracking_label",
fontsize: Optional[int] = None,
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Expand All @@ -1761,10 +1754,11 @@ def overlay_segmentation_masks(
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
the masks on.
masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries containing the masks, labels and scores.
dictionaries containing the masks and labels.
draw_label (bool, optional): If True, the labels will be displayed on the image.
secondary_label_key (str, optional): The key to use for the secondary
tracking label which is needed in videos to display tracking information.
fontsize (Optional[int], optional): The font size to use for the labels in case of needing a custom size.
Returns:
np.ndarray: The image with the masks displayed.
Expand All @@ -1774,7 +1768,6 @@ def overlay_segmentation_masks(
>>> image_with_masks = overlay_segmentation_masks(
image,
[{
'score': 0.99,
'label': 'dinosaur',
'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
Expand All @@ -1785,19 +1778,21 @@ def overlay_segmentation_masks(
)
"""
medias_int: List[np.ndarray] = (
[medias] if isinstance(medias, np.ndarray) else medias
[media for media in medias] if isinstance(medias, np.ndarray) else medias
)
masks_int = [masks] if isinstance(masks[0], dict) else masks
masks_int = cast(List[List[Dict[str, Any]]], masks_int)

labels = set()
for mask_i in masks_int:
for mask_j in mask_i:
labels.add(mask_j["label"])
if mask_i is not None:
for mask_j in mask_i:
labels.add(mask_j["label"])
color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

width, height = Image.fromarray(medias_int[0]).size
fontsize = max(12, int(min(width, height) / 40))
if fontsize is None:
fontsize = max(12, int(min(width, height) / 40))
font = ImageFont.truetype(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
Expand All @@ -1806,28 +1801,31 @@ def overlay_segmentation_masks(
frame_out = []
for i, frame in enumerate(medias_int):
pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
for elt in masks_int[i]:
mask = elt["mask"]
label = elt["label"]
tracking_lbl = elt.get(secondary_label_key, None)
np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

if draw_label:
draw = ImageDraw.Draw(pil_image)
text = tracking_lbl if tracking_lbl else label
text_box = draw.textbbox((0, 0), text=text, font=font)
x, y = _get_text_coords_from_mask(
mask,
v_gap=(text_box[3] - text_box[1]) + 10,
h_gap=(text_box[2] - text_box[0]) // 2,
)
if x != 0 and y != 0:
text_box = draw.textbbox((x, y), text=text, font=font)
draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
draw.text((x, y), text, fill="black", font=font)
if masks_int[i] is not None:
for elt in masks_int[i]:
mask = elt["mask"]
label = elt["label"]
tracking_lbl = elt.get(secondary_label_key, None)
np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)

if draw_label and label is not None:
draw = ImageDraw.Draw(pil_image)
text = tracking_lbl if tracking_lbl else label
text_box = draw.textbbox((0, 0), text=text, font=font)
x, y = get_text_coords_from_mask(
mask,
v_gap=(text_box[3] - text_box[1]) + 10,
h_gap=(text_box[2] - text_box[0]) // 2,
)
if x != 0 and y != 0:
text_box = draw.textbbox((x, y), text=text, font=font)
draw.rectangle(
(x, y, text_box[2], text_box[3]), fill=color[label]
)
draw.text((x, y), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out

Expand Down Expand Up @@ -1935,11 +1933,7 @@ def overlay_counting_results(

# Draw the text at the center of the bounding box
draw.text(
(text_x0, text_y0),
label,
fill="black",
font=font,
anchor="lt",
(text_x0, text_y0), label, fill="black", font=font, anchor="lt",
)

return np.array(pil_image)
Expand Down

0 comments on commit ecd8a37

Please sign in to comment.