     convert_quad_box_to_bbox,
     convert_to_b64,
     denormalize_bbox,
+    frames_to_bytes,
     get_image_size,
     normalize_bbox,
     numpy_to_bytes,
@@ -184,10 +185,10 @@ def grounding_sam(
     box_threshold: float = 0.20,
     iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
-    """'grounding_sam' is a tool that can segment multiple objects given a
-    text prompt such as category names or referring expressions. The categories in text
-    prompt are separated by commas or periods. It returns a list of bounding boxes,
-    label names, mask file names and associated probability scores.
+    """'grounding_sam' is a tool that can segment multiple objects given a text prompt
+    such as category names or referring expressions. The categories in the text prompt
+    are separated by commas or periods. It returns a list of bounding boxes, label
+    names, mask file names and associated probability scores.

     Parameters:
         prompt (str): The prompt to ground to the image.
@@ -245,8 +246,8 @@ def grounding_sam(


 def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
-    """'florence2_sam2_image' is a tool that can segment multiple objects given a
-    text prompt such as category names or referring expressions. The categories in text
+    """'florence2_sam2_image' is a tool that can segment multiple objects given a text
+    prompt such as category names or referring expressions. The categories in the text
     prompt are separated by commas. It returns a list of bounding boxes, label names,
     mask file names and associated probability scores.

@@ -297,6 +298,63 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
     return return_data


+def florence2_sam2_video(
+    prompt: str, frames: List[np.ndarray]
+) -> List[List[Dict[str, Any]]]:
+    """'florence2_sam2_video' is a tool that can segment and track multiple objects
+    in a video given a text prompt such as category names or referring expressions. The
+    categories in the text prompt are separated by commas. It returns tracked objects
+    as masks, labels, and scores for each frame.
+
+    Parameters:
+        prompt (str): The prompt to ground to the video.
+        frames (List[np.ndarray]): The list of frames to ground the prompt to.
+
+    Returns:
+        List[List[Dict[str, Any]]]: A list of lists of dictionaries containing the
+            label, score and mask of the detected objects. The outer list represents
+            each frame and the inner list is the objects per frame. The label contains
+            the object ID followed by the label name. The objects are only identified
+            in the first frame and tracked throughout the video.
+
+    Example
+    -------
+    >>> florence2_sam2_video("car, dinosaur", frames)
+    [
+        [
+            {
+                'label': '0: dinosaur',
+                'score': 1.0,
+                'mask': array([[0, 0, 0, ..., 0, 0, 0],
+                    [0, 0, 0, ..., 0, 0, 0],
+                    ...,
+                    [0, 0, 0, ..., 0, 0, 0],
+                    [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
+            },
+        ],
+    ]
+    """
+
+    buffer_bytes = frames_to_bytes(frames)
+    files = [("video", buffer_bytes)]
+    payload = {
+        "prompts": prompt.split(","),
+        "function_name": "florence2_sam2_video",
+    }
+    data: Dict[str, Any] = send_inference_request(
+        payload, "florence2-sam2", files=files, v2=True
+    )
+    return_data = []
+    for frame_i in data.keys():
+        return_frame_data = []
+        for obj_id, data_j in data[frame_i].items():
+            mask = rle_decode_array(data_j["mask"])
+            label = obj_id + ": " + data_j["label"]
+            return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
+        return_data.append(return_frame_data)
+    return return_data
+
+
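Taken together with `extract_frames` below, the new tool composes into a simple segment-and-track pipeline. A minimal usage sketch, not part of this commit: `video.mp4` is a hypothetical path, and the imports assume all three helpers are exported from `vision_agent.tools`.

```python
from vision_agent.tools import (
    extract_frames,
    florence2_sam2_video,
    overlay_segmentation_masks,
)

# extract_frames returns (frame, timestamp) tuples; keep only the frames.
frames = [frame for frame, _ in extract_frames("video.mp4", fps=1.0)]

# Objects are identified in the first frame and tracked through the rest;
# the result is one list of {label, mask, score} dicts per frame.
detections = florence2_sam2_video("car, dinosaur", frames)

# With a list of frames, overlay_segmentation_masks (updated below) returns
# a list of annotated frames, one per input frame.
annotated_frames = overlay_segmentation_masks(frames, detections)
```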
 def extract_frames(
     video_uri: Union[str, Path], fps: float = 0.5
 ) -> List[Tuple[np.ndarray, float]]:
@@ -1274,15 +1332,43 @@ def overlay_bounding_boxes(
     return np.array(pil_image)


+def _get_text_coords_from_mask(
+    mask: np.ndarray, v_gap: int = 10, h_gap: int = 10
+) -> Tuple[int, int]:
+    mask = mask.astype(np.uint8)
+    if np.sum(mask) == 0:
+        return (0, 0)
+
+    rows, cols = np.nonzero(mask)
+    top = rows.min()
+    bottom = rows.max()
+    left = cols.min()
+    right = cols.max()
+
+    # Anchor the label v_gap pixels above the mask; if there is no room
+    # above, fall back to below the mask (unless that would run off the
+    # bottom edge, in which case keep the mask's own top row).
+    if top - v_gap < 0:
+        if bottom + v_gap <= mask.shape[0]:
+            top = bottom + v_gap
+    else:
+        top = top - v_gap
+
+    return left + (right - left) // 2 - h_gap, top
+
+
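The new private helper picks where a mask's label should be drawn: the anchor sits `v_gap` pixels above the mask's bounding box, drops below the mask when there is no room above, and the `(0, 0)` return value doubles as an empty-mask sentinel. A small self-contained check of that behavior (a sketch only; the helper is module-private, so this assumes it runs inside the module):

```python
import numpy as np

# Toy 100x100 mask with a filled square spanning rows 40-59, cols 30-69.
mask = np.zeros((100, 100), dtype=np.uint8)
mask[40:60, 30:70] = 1

# Room above: anchor 10 px above the top row, horizontally centered on the
# mask minus h_gap: x = 30 + (69 - 30) // 2 - 5 = 44, y = 40 - 10 = 30.
print(_get_text_coords_from_mask(mask, v_gap=10, h_gap=5))  # (44, 30)

# No room above: the anchor drops below the mask instead: y = 9 + 10 = 19.
mask_top = np.zeros((100, 100), dtype=np.uint8)
mask_top[0:10, 30:70] = 1
print(_get_text_coords_from_mask(mask_top, v_gap=10, h_gap=5))  # (44, 19)
```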
 def overlay_segmentation_masks(
-    image: np.ndarray, masks: List[Dict[str, Any]]
-) -> np.ndarray:
+    medias: Union[np.ndarray, List[np.ndarray]],
+    masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
+    draw_label: bool = True,
+) -> Union[np.ndarray, List[np.ndarray]]:
     """'overlay_segmentation_masks' is a utility function that displays segmentation
     masks.

     Parameters:
-        image (np.ndarray): The image to display the masks on.
-        masks (List[Dict[str, Any]]): A list of dictionaries containing the masks.
+        medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
+            the masks on.
+        masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
+            dictionaries containing the masks, or a list of such lists (one per frame).
+        draw_label (bool, optional): Whether to draw each mask's label on the image.
+            Defaults to True.

     Returns:
-        np.ndarray: The image with the masks displayed.
+        Union[np.ndarray, List[np.ndarray]]: The image or frames with the masks
+            displayed.
@@ -1302,27 +1388,50 @@ def overlay_segmentation_masks(
         }],
     )
     """
-    pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGBA")
+    medias_int: List[np.ndarray] = (
+        [medias] if isinstance(medias, np.ndarray) else medias
+    )
+    masks_int = [masks] if isinstance(masks[0], dict) else masks
+    masks_int = cast(List[List[Dict[str, Any]]], masks_int)

-    if len(set([mask["label"] for mask in masks])) > len(COLORS):
-        _LOGGER.warning(
-            "Number of unique labels exceeds the number of available colors. Some labels may have the same color."
-        )
+    labels = set()
+    for mask_i in masks_int:
+        for mask_j in mask_i:
+            labels.add(mask_j["label"])
+    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

-    color = {
-        label: COLORS[i % len(COLORS)]
-        for i, label in enumerate(set([mask["label"] for mask in masks]))
-    }
-    masks = sorted(masks, key=lambda x: x["label"], reverse=True)
+    width, height = Image.fromarray(medias_int[0]).size
+    fontsize = max(12, int(min(width, height) / 40))
+    font = ImageFont.truetype(
+        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
+        fontsize,
+    )

-    for elt in masks:
-        mask = elt["mask"]
-        label = elt["label"]
-        np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
-        np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
-        mask_img = Image.fromarray(np_mask.astype(np.uint8))
-        pil_image = Image.alpha_composite(pil_image, mask_img)
-    return np.array(pil_image)
+    frame_out = []
+    for i, frame in enumerate(medias_int):
+        pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
+        for elt in masks_int[i]:
+            mask = elt["mask"]
+            label = elt["label"]
+            np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
+            np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
+            mask_img = Image.fromarray(np_mask.astype(np.uint8))
+            pil_image = Image.alpha_composite(pil_image, mask_img)
+
+            if draw_label:
+                draw = ImageDraw.Draw(pil_image)
+                text_box = draw.textbbox((0, 0), text=label, font=font)
+                x, y = _get_text_coords_from_mask(
+                    mask,
+                    v_gap=(text_box[3] - text_box[1]) + 10,
+                    h_gap=(text_box[2] - text_box[0]) // 2,
+                )
+                # (0, 0) is the empty-mask sentinel; skip the label in that case.
+                if x != 0 and y != 0:
+                    text_box = draw.textbbox((x, y), text=label, font=font)
+                    draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
+                    draw.text((x, y), label, fill="black", font=font)
+        frame_out.append(np.array(pil_image))  # type: ignore
+    return frame_out[0] if len(frame_out) == 1 else frame_out


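The reworked `overlay_segmentation_masks` stays backward compatible: a bare `np.ndarray` with a flat list of detections takes the single-image path and returns a single array (note the output is RGBA, so it gains an alpha channel), while lists of frames flow through one frame at a time. A sketch with synthetic data, assuming the function is imported from `vision_agent.tools`:

```python
import numpy as np

from vision_agent.tools import overlay_segmentation_masks

# Synthetic 480x640 RGB image and one full-frame detection.
image = np.zeros((480, 640, 3), dtype=np.uint8)
dets = [
    {
        "label": "dinosaur",
        "score": 1.0,
        "mask": np.ones((480, 640), dtype=np.uint8),
    }
]

# Single image in, single annotated image out (frame_out[0] internally);
# the result is RGBA, hence 4 channels.
out = overlay_segmentation_masks(image, dets, draw_label=False)
assert isinstance(out, np.ndarray) and out.shape == (480, 640, 4)
```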
 def overlay_heat_map(