Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

overlay bboxes works with frames #244

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/lmm.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

::: vision_agent.lmm.OllamaLMM

::: vision_agent.lmm.ClaudeSonnetLMM
::: vision_agent.lmm.AnthropicLMM
79 changes: 48 additions & 31 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,14 +1759,17 @@ def _save_video_to_result(video_uri: str) -> None:


def overlay_bounding_boxes(
image: np.ndarray, bboxes: List[Dict[str, Any]]
) -> np.ndarray:
medias: Union[np.ndarray, List[np.ndarray]],
bboxes: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_bounding_boxes' is a utility function that displays bounding boxes on
an image.

Parameters:
image (np.ndarray): The image to display the bounding boxes on.
bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding
medias (Union[np.ndarray, List[np.ndarra]]): The image or frames to display the
bounding boxes on.
bboxes (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries or a list of list of dictionaries containing the bounding
boxes.

Returns:
Expand All @@ -1778,41 +1781,54 @@ def overlay_bounding_boxes(
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if len(set([box["label"] for box in bboxes])) > len(COLORS):
medias_int: List[np.ndarray] = (
[medias] if isinstance(medias, np.ndarray) else medias
)
bbox_int = [bboxes] if isinstance(bboxes[0], dict) else bboxes
bbox_int = cast(List[List[Dict[str, Any]]], bbox_int)
labels = set([bb["label"] for b in bbox_int for bb in b])

if len(labels) > len(COLORS):
_LOGGER.warning(
"Number of unique labels exceeds the number of available colors. Some labels may have the same color."
)

color = {
label: COLORS[i % len(COLORS)]
for i, label in enumerate(set([box["label"] for box in bboxes]))
}
bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)
color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
)
frame_out = []
for i, frame in enumerate(medias_int):
pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGB")

for elt in bboxes:
label = elt["label"]
box = elt["bbox"]
scores = elt["score"]
bboxes = bbox_int[i]
bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

# denormalize the box if it is normalized
box = denormalize_bbox(box, (height, width))
width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype(
str(
resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")
),
fontsize,
)

draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)
return np.array(pil_image)
for elt in bboxes:
label = elt["label"]
box = elt["bbox"]
scores = elt["score"]

# denormalize the box if it is normalized
box = denormalize_bbox(box, (height, width))
draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle(
(box[0], box[1], text_box[2], text_box[3]), fill=color[label]
)
draw.text((box[0], box[1]), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out


def _get_text_coords_from_mask(
Expand Down Expand Up @@ -1852,7 +1868,8 @@ def overlay_segmentation_masks(
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
the masks on.
masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries containing the masks, labels and scores.
dictionaries or a list of list of dictionaries containing the masks, labels
and scores.
draw_label (bool, optional): If True, the labels will be displayed on the image.
secondary_label_key (str, optional): The key to use for the secondary
tracking label which is needed in videos to display tracking information.
Expand Down
Loading