27
27
convert_quad_box_to_bbox ,
28
28
convert_to_b64 ,
29
29
denormalize_bbox ,
30
+ frames_to_bytes ,
30
31
get_image_size ,
31
32
normalize_bbox ,
32
33
numpy_to_bytes ,
@@ -184,10 +185,10 @@ def grounding_sam(
184
185
box_threshold : float = 0.20 ,
185
186
iou_threshold : float = 0.20 ,
186
187
) -> List [Dict [str , Any ]]:
187
- """'grounding_sam' is a tool that can segment multiple objects given a
188
- text prompt such as category names or referring expressions. The categories in text
189
- prompt are separated by commas or periods. It returns a list of bounding boxes,
190
- label names, mask file names and associated probability scores.
188
+ """'grounding_sam' is a tool that can segment multiple objects given a text prompt
189
+ such as category names or referring expressions. The categories in text prompt are
190
+ separated by commas or periods. It returns a list of bounding boxes, label names ,
191
+ mask file names and associated probability scores.
191
192
192
193
Parameters:
193
194
prompt (str): The prompt to ground to the image.
@@ -245,8 +246,8 @@ def grounding_sam(
245
246
246
247
247
248
def florence2_sam2_image (prompt : str , image : np .ndarray ) -> List [Dict [str , Any ]]:
248
- """'florence2_sam2_image' is a tool that can segment multiple objects given a
249
- text prompt such as category names or referring expressions. The categories in text
249
+ """'florence2_sam2_image' is a tool that can segment multiple objects given a text
250
+ prompt such as category names or referring expressions. The categories in the text
250
251
prompt are separated by commas. It returns a list of bounding boxes, label names,
251
252
mask file names and associated probability scores.
252
253
@@ -297,6 +298,63 @@ def florence2_sam2_image(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]
297
298
return return_data
298
299
299
300
301
def florence2_sam2_video(
    prompt: str, frames: List[np.ndarray]
) -> List[List[Dict[str, Any]]]:
    """'florence2_sam2_video' is a tool that can segment and track multiple objects
    in a video given a text prompt such as category names or referring expressions. The
    categories in the text prompt are separated by commas. It returns tracked objects
    as masks, labels, and scores for each frame.

    Parameters:
        prompt (str): The prompt to ground to the video.
        frames (List[np.ndarray]): The list of frames to ground the prompt to.

    Returns:
        List[List[Dict[str, Any]]]: A list of list of dictionaries containing the label,
            score and mask of the detected objects. The outer list represents each frame
            and the inner list is the objects per frame. The label contains the object ID
            followed by the label name. The objects are only identified in the first frame
            and tracked throughout the video.

    Example
    -------
        >>> florence2_sam2_video("car, dinosaur", frames)
        [
            [
                {
                    'label': '0: dinosaur',
                    'score': 1.0,
                    'mask': array([[0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0],
                        ...,
                        [0, 0, 0, ..., 0, 0, 0],
                        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
                },
            ],
        ]
    """

    # Encode the frames once and ship them as a single video payload.
    buffer_bytes = frames_to_bytes(frames)
    files = [("video", buffer_bytes)]
    payload = {
        "prompts": prompt.split(","),
        "function_name": "florence2_sam2_video",
    }
    data: Dict[str, Any] = send_inference_request(
        payload, "florence2-sam2", files=files, v2=True
    )
    # Response maps frame index -> {object_id: {"mask": rle, "label": str}}.
    return_data: List[List[Dict[str, Any]]] = []
    for frame_i in data.keys():
        return_frame_data: List[Dict[str, Any]] = []
        for obj_id, data_j in data[frame_i].items():
            # Masks come back run-length encoded; decode to a numpy array.
            mask = rle_decode_array(data_j["mask"])
            label = obj_id + ": " + data_j["label"]
            # NOTE(review): the service does not return per-object scores for
            # video tracking, so a fixed 1.0 is reported — confirm upstream.
            return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
        return_data.append(return_frame_data)
    return return_data
300
358
def extract_frames (
301
359
video_uri : Union [str , Path ], fps : float = 0.5
302
360
) -> List [Tuple [np .ndarray , float ]]:
@@ -1274,15 +1332,43 @@ def overlay_bounding_boxes(
1274
1332
return np .array (pil_image )
1275
1333
1276
1334
1335
+ def _get_text_coords_from_mask (
1336
+ mask : np .ndarray , v_gap : int = 10 , h_gap : int = 10
1337
+ ) -> Tuple [int , int ]:
1338
+ mask = mask .astype (np .uint8 )
1339
+ if np .sum (mask ) == 0 :
1340
+ return (0 , 0 )
1341
+
1342
+ rows , cols = np .nonzero (mask )
1343
+ top = rows .min ()
1344
+ bottom = rows .max ()
1345
+ left = cols .min ()
1346
+ right = cols .max ()
1347
+
1348
+ if top - v_gap < 0 :
1349
+ if bottom + v_gap > mask .shape [0 ]:
1350
+ top = top
1351
+ else :
1352
+ top = bottom + v_gap
1353
+ else :
1354
+ top = top - v_gap
1355
+
1356
+ return left + (right - left ) // 2 - h_gap , top
1357
+
1358
+
def overlay_segmentation_masks(
    medias: Union[np.ndarray, List[np.ndarray]],
    masks: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
    draw_label: bool = True,
) -> Union[np.ndarray, List[np.ndarray]]:
    """'overlay_segmentation_masks' is a utility function that displays segmentation
    masks.

    Parameters:
        medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
            the masks on.
        masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
            dictionaries containing the masks, or a list of such lists (one per
            frame).
        draw_label (bool): If True, also draw each mask's label next to the mask.

    Returns:
        Union[np.ndarray, List[np.ndarray]]: The image with the masks displayed,
            or the list of frames when multiple frames were supplied.

    Example
    -------
        >>> image_with_masks = overlay_segmentation_masks(
            image,
            [{
                'score': 0.99,
                'label': 'dinosaur',
                'mask': array([[0, 0, 0, ..., 0, 0, 0],
                    [0, 0, 0, ..., 0, 0, 0],
                    ...,
                    [0, 0, 0, ..., 0, 0, 0],
                    [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
            }],
        )
    """
    # Normalize both arguments to the per-frame (list-of-lists) shape so the
    # single-image and video cases share one code path.
    medias_int: List[np.ndarray] = (
        [medias] if isinstance(medias, np.ndarray) else medias
    )
    masks_int = [masks] if isinstance(masks[0], dict) else masks
    masks_int = cast(List[List[Dict[str, Any]]], masks_int)

    # Assign one color per unique label across ALL frames so a tracked object
    # keeps its color from frame to frame.
    labels = set()
    for mask_i in masks_int:
        for mask_j in mask_i:
            labels.add(mask_j["label"])
    color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

    # Scale the label font with the frame size; assumes all frames share the
    # first frame's dimensions.
    width, height = Image.fromarray(medias_int[0]).size
    fontsize = max(12, int(min(width, height) / 40))
    font = ImageFont.truetype(
        str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
        fontsize,
    )

    frame_out = []
    for i, frame in enumerate(medias_int):
        pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGBA")
        for elt in masks_int[i]:
            mask = elt["mask"]
            label = elt["label"]
            # Alpha-composite the mask onto the frame at 50% opacity.
            np_mask = np.zeros((pil_image.size[1], pil_image.size[0], 4))
            np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
            mask_img = Image.fromarray(np_mask.astype(np.uint8))
            pil_image = Image.alpha_composite(pil_image, mask_img)

            if draw_label:
                draw = ImageDraw.Draw(pil_image)
                text_box = draw.textbbox((0, 0), text=label, font=font)
                x, y = _get_text_coords_from_mask(
                    mask,
                    v_gap=(text_box[3] - text_box[1]) + 10,
                    h_gap=(text_box[2] - text_box[0]) // 2,
                )
                # (0, 0) is the helper's empty-mask sentinel; skip drawing then.
                if x != 0 and y != 0:
                    text_box = draw.textbbox((x, y), text=label, font=font)
                    draw.rectangle((x, y, text_box[2], text_box[3]), fill=color[label])
                    draw.text((x, y), label, fill="black", font=font)
        frame_out.append(np.array(pil_image))  # type: ignore
    return frame_out[0] if len(frame_out) == 1 else frame_out
1326
1435
1327
1436
1328
1437
def overlay_heat_map (
0 commit comments