added florence2+sam2 for video
dillonalaird committed Aug 21, 2024
1 parent 9ce9ec8 commit e24762e
Showing 3 changed files with 158 additions and 28 deletions.
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
@@ -21,6 +21,7 @@
    florence2_ocr,
    florence2_roberta_vqa,
    florence2_sam2_image,
+    florence2_sam2_video,
    generate_pose_image,
    generate_soft_edge_image,
    get_tool_documentation,
165 changes: 137 additions & 28 deletions vision_agent/tools/tools.py
@@ -27,6 +27,7 @@
    convert_quad_box_to_bbox,
    convert_to_b64,
    denormalize_bbox,
+    frames_to_bytes,
    get_image_size,
    normalize_bbox,
    numpy_to_bytes,
@@ -184,10 +185,10 @@ def grounding_sam(
    box_threshold: float = 0.20,
    iou_threshold: float = 0.20,
) -> List[Dict[str, Any]]:
-    """'grounding_sam' is a tool that can segment multiple objects given a
-    text prompt such as category names or referring expressions. The categories in text
-    prompt are separated by commas or periods. It returns a list of bounding boxes,
-    label names, mask file names and associated probability scores.
+    """'grounding_sam' is a tool that can segment multiple objects given a text prompt
+    such as category names or referring expressions. The categories in text prompt are
+    separated by commas or periods. It returns a list of bounding boxes, label names,
+    mask file names and associated probability scores.
Parameters:
prompt (str): The prompt to ground to the image.
@@ -245,8 +246,8 @@ def grounding_sam(


def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
-    """'florence2_sam2_image' is a tool that can segment multiple objects given a
-    text prompt such as category names or referring expressions. The categories in text
+    """'florence2_sam2_image' is a tool that can segment multiple objects given a text
+    prompt such as category names or referring expressions. The categories in the text
    prompt are separated by commas. It returns a list of bounding boxes, label names,
    mask file names and associated probability scores.
@@ -297,6 +298,63 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
return return_data

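For comparison with the video variant added below, a minimal usage sketch of the image tool (the file name is hypothetical; the tool takes an RGB numpy array):

    import numpy as np
    from PIL import Image

    from vision_agent.tools import florence2_sam2_image

    image = np.array(Image.open("dinosaur.jpg").convert("RGB"))
    detections = florence2_sam2_image("dinosaur", image)
    for det in detections:
        print(det["label"], det["score"], det["mask"].shape)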

+def florence2_sam2_video(
+    prompt: str, frames: List[np.ndarray]
+) -> List[List[Dict[str, Any]]]:
+    """'florence2_sam2_video' is a tool that can segment and track multiple objects
+    in a video given a text prompt such as category names or referring expressions.
+    The categories in the text prompt are separated by commas. It returns tracked
+    objects as masks, labels, and scores for each frame.
+
+    Parameters:
+        prompt (str): The prompt to ground to the video.
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+
+    Returns:
+        List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the
+            label, score and mask of the detected objects. The outer list represents
+            each frame and the inner list is the objects per frame. The label
+            contains the object ID followed by the label name. The objects are only
+            identified in the first frame and then tracked throughout the video.
+
+    Example
+    -------
+        >>> florence2_sam2_video("car, dinosaur", frames)
+        [
+            [
+                {
+                    'label': '0: dinosaur',
+                    'score': 1.0,
+                    'mask': array([[0, 0, 0, ..., 0, 0, 0],
+                        [0, 0, 0, ..., 0, 0, 0],
+                        ...,
+                        [0, 0, 0, ..., 0, 0, 0],
+                        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+                },
+            ],
+        ]
+    """
+    buffer_bytes = frames_to_bytes(frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        # strip whitespace so "car, dinosaur" yields ["car", "dinosaur"]
+        "prompts": [p.strip() for p in prompt.split(",")],
+        "function_name": "florence2_sam2_video",
+    }
+    data: Dict[str, Any] = send_inference_request(
+        payload, "florence2-sam2", files=files, v2=True
+    )
+    return_data = []
+    for frame_i in data.keys():
+        return_frame_data = []
+        for obj_id, data_j in data[frame_i].items():
+            mask = rle_decode_array(data_j["mask"])
+            # e.g. "0: dinosaur"; the object ID stays stable across frames
+            label = obj_id + ": " + data_j["label"]
+            return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
+        return_data.append(return_frame_data)
+    return return_data

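A minimal end-to-end sketch for the new tool (the file name is hypothetical; extract_frames below returns (frame, timestamp) tuples, so the timestamps are dropped before calling the tool):

    from vision_agent.tools import extract_frames, florence2_sam2_video

    frames = [frame for frame, _ in extract_frames("video.mp4", fps=1.0)]
    tracks = florence2_sam2_video("car, dinosaur", frames)
    # one inner list per frame; object IDs in the labels stay stable across frames
    print(len(tracks), len(tracks[0]), tracks[0][0]["label"])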

def extract_frames(
    video_uri: Union[str, Path], fps: float = 0.5
) -> List[Tuple[np.ndarray, float]]:
@@ -1274,15 +1332,43 @@ def overlay_bounding_boxes(
return np.array(pil_image)


+def _get_text_coords_from_mask(
+    mask: np.ndarray, v_gap: int = 10, h_gap: int = 10
+) -> Tuple[int, int]:
+    # choose an (x, y) anchor for a text label: horizontally centered on the
+    # mask and placed above it when there is room, otherwise below it
+    mask = mask.astype(np.uint8)
+    if np.sum(mask) == 0:
+        return (0, 0)
+
+    rows, cols = np.nonzero(mask)
+    top = rows.min()
+    bottom = rows.max()
+    left = cols.min()
+    right = cols.max()
+
+    if top - v_gap < 0:
+        if bottom + v_gap > mask.shape[0]:
+            # no room above or below; keep the label at the mask's top edge
+            top = top
+        else:
+            top = bottom + v_gap
+    else:
+        top = top - v_gap
+
+    return left + (right - left) // 2 - h_gap, top

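A quick worked check of the placement rule, with the helper above in scope (toy numbers):

    import numpy as np

    # toy 100x100 mask covering rows 30-40 and columns 20-60
    mask = np.zeros((100, 100), dtype=np.uint8)
    mask[30:41, 20:61] = 1
    # x = 20 + (60 - 20) // 2 - 5 = 35 and y = 30 - 10 = 20,
    # i.e. the label sits horizontally centered just above the mask
    print(_get_text_coords_from_mask(mask, v_gap=10, h_gap=5))  # (35, 20)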

def overlay_segmentation_masks(
-    image: np.ndarray, masks: List[Dict[str, Any]]
-) -> np.ndarray:
+    medias: Union[np.ndarray, List[np.ndarray]],
+    masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
+    draw_label: bool = True,
+) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Parameters:
-        image (np.ndarray): The image to display the masks on.
-        masks (List[Dict[str, Any]]): A list of dictionaries containing the masks.
+        medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
+            the masks on.
+        masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
+            dictionaries containing the masks.
Returns:
np.ndarray: The image with the masks displayed.
@@ -1302,27 +1388,50 @@ def overlay_segmentation_masks(
}],
)
"""
-    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
+    medias_int: List[np.ndarray] = (
+        [medias] if isinstance(medias, np.ndarray) else medias
+    )
+    masks_int = [masks] if isinstance(masks[0], dict) else masks
+    masks_int = cast(List[List[Dict[str, Any]]], masks_int)

-    if len(set([mask["label"] for mask in masks])) > len(COLORS):
-        _LOGGER.warning(
-            "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
-        )
+    labels = set()
+    for mask_i in masks_int:
+        for mask_j in mask_i:
+            labels.add(mask_j["label"])
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

-    color = {
-        label: COLORS[i % len(COLORS)]
-        for i, label in enumerate(set([mask["label"] for mask in masks]))
-    }
-    masks = sorted(masks, key=lambda x: x["label"], reverse=True)
+    width, height = Image.fromarray(medias_int[0]).size
+    fontsize = max(12, int(min(width, height) / 40))
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )

-    for elt in masks:
-        mask = elt["mask"]
-        label = elt["label"]
-        np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
-        np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
-        mask_img = Image.fromarray(np_mask.astype(np.uint8))
-        pil_image = Image.alpha_composite(pil_image, mask_img)
-    return np.array(pil_image)
+    frame_out = []
+    for i, frame in enumerate(medias_int):
+        pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
+        for elt in masks_int[i]:
+            mask = elt["mask"]
+            label = elt["label"]
+            # paint the mask region with the label color at 50% opacity
+            np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
+            np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
+            mask_img = Image.fromarray(np_mask.astype(np.uint8))
+            pil_image = Image.alpha_composite(pil_image, mask_img)
+
+            if draw_label:
+                draw = ImageDraw.Draw(pil_image)
+                text_box = draw.textbbox((0, 0), text=label, font=font)
+                x, y = _get_text_coords_from_mask(
+                    mask,
+                    v_gap=(text_box[3] - text_box[1]) + 10,
+                    h_gap=(text_box[2] - text_box[0]) // 2,
+                )
+                if x != 0 and y != 0:
+                    text_box = draw.textbbox((x, y), text=label, font=font)
+                    draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
+                    draw.text((x, y), label, fill="black", font=font)
+        frame_out.append(np.array(pil_image))  # type: ignore
+    return frame_out[0] if len(frame_out) == 1 else frame_out

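The reworked utility accepts either a single image with a flat list of masks or a list of frames with per-frame mask lists; a short sketch (the variable names are assumptions, e.g. tracks from the florence2_sam2_video example above):

    from vision_agent.tools import overlay_segmentation_masks

    # single image + flat list of detections -> one annotated array
    annotated = overlay_segmentation_masks(image, detections)

    # frames + per-frame detections -> a list of annotated frames
    annotated_frames = overlay_segmentation_masks(frames, tracks, draw_label=True)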

def overlay_heat_map(
20 changes: 20 additions & 0 deletions vision_agent/utils/image_utils.py
@@ -2,12 +2,14 @@

import base64
import io
+import tempfile
from importlib import resources
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Tuple, Union

import numpy as np
+from moviepy.editor import ImageSequenceClip
from PIL import Image, ImageDraw, ImageFont
from PIL.Image import Image as ImageType

@@ -86,6 +88,24 @@ def rle_decode_array(rle: Dict[str, List[int]]) -> np.ndarray:
return binary_mask

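For context on the masks decoded in florence2_sam2_video above, a minimal sketch of a counts-style run-length decoder; the wire format here (COCO-style alternating runs over a column-major flattened mask) is an assumption, not necessarily the package's exact scheme:

    import numpy as np
    from typing import Dict, List

    def rle_decode_array_sketch(rle: Dict[str, List[int]]) -> np.ndarray:
        # assumed payload: {"size": [h, w], "counts": [...]} with counts
        # alternating background/foreground run lengths
        h, w = rle["size"]
        flat = np.zeros(h * w, dtype=np.uint8)
        pos, value = 0, 0
        for run in rle["counts"]:
            flat[pos : pos + run] = value
            pos += run
            value = 1 - value
        # undo the column-major (Fortran-order) flattening
        return flat.reshape((h, w), order="F")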

+def frames_to_bytes(
+    frames: List[np.ndarray], fps: float = 10, file_ext: str = "mp4"
+) -> bytes:
+    r"""Convert a list of frames to a video file encoded into a byte string.
+
+    Parameters:
+        frames: the list of frames
+        fps: the frames per second of the video
+        file_ext: the file extension of the video file
+    """
+    # write into a temporary directory so the encoded file is removed automatically
+    with tempfile.TemporaryDirectory() as temp_dir:
+        video_path = Path(temp_dir) / f"video.{file_ext}"
+        clip = ImageSequenceClip(frames, fps=fps)
+        clip.write_videofile(str(video_path), fps=fps)
+        buffer_bytes = video_path.read_bytes()
+    return buffer_bytes

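A round-trip sketch for the new helper (the output path is hypothetical):

    from vision_agent.utils.image_utils import frames_to_bytes

    video_bytes = frames_to_bytes(annotated_frames, fps=10, file_ext="mp4")
    with open("tracked.mp4", "wb") as f:
        f.write(video_bytes)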

def b64_to_pil(b64_str: str) -> ImageType:
r"""Convert a base64 string to a PIL Image.
