20
20
21
21
from sam2 .build_sam import build_sam2_video_predictor
22
22
23
+ parser = argparse .ArgumentParser ()
24
+ parser .add_argument ("--config" , type = str , required = True )
25
+ parser .add_argument ("--model_path" , type = str , default = "../" )
26
+ parser .add_argument ("--pretrain_model_path" , type = str , default = "../" )
27
+ parser .add_argument ("--sub_folder" , type = str , default = "unet" )
28
+ args = parser .parse_args ()
29
+
23
30
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
24
31
model_cfg = "sam2_hiera_l.yaml"
25
32
26
33
predictor = build_sam2_video_predictor (model_cfg , sam2_checkpoint )
27
34
28
- config = OmegaConf .load ("./configs/code_release.yaml" )
29
- validation_pipeline = load_model (model_path = "../" , \
30
- sub_folder = "unet2" , \
31
- pretrained_model_path = "/data/home/zibojia/StableDiffusion/" , \
35
+ config = OmegaConf .load (args . config )
36
+ validation_pipeline = load_model (model_path = args . model_path , \
37
+ sub_folder = args . sub_folder , \
38
+ pretrained_model_path = args . pretrain_model_path , \
32
39
** config )
33
40
34
- video_path = 'images/images.npy'
35
- mask_path = 'images/masks.npy'
36
- validation_images2 = 2 * (np .load (video_path )/ 255.0 - 0.5 )
37
- validation_masks2 = np .load (mask_path )/ 255.0
38
-
39
- print ('validation_images2' , validation_images2 .shape )
40
- print ('validation_masks2' , validation_masks2 .shape )
41
-
42
-
43
-
44
41
def init_state (
45
42
offload_video_to_cpu = False ,
46
43
offload_state_to_cpu = False
@@ -108,7 +105,7 @@ def get_frames_from_video(video_input, video_state):
108
105
frames = vr .get_batch (list (range (len (vr )))).asnumpy ()
109
106
inference_state = predictor .init_state (images = frames )
110
107
fps = 30
111
- image_size = (frames [0 ].shape [0 ],frames [0 ].shape [1 ])
108
+ image_size = (frames [0 ].shape [0 ],frames [0 ].shape [1 ])
112
109
# initialize video_state
113
110
video_state = {
114
111
"user_name" : user_name ,
@@ -234,7 +231,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
234
231
mask_number = int (mask_dropdown [i ].split ("_" )[1 ]) - 1
235
232
mask = interactive_state ["multi_mask" ]["masks" ][mask_number ]
236
233
select_frame = mask_painter (select_frame , mask .astype ('uint8' ), mask_color = mask_number + 2 )
237
-
234
+
238
235
operation_log = [("" ,"" ), ("Select {} for tracking or inpainting" .format (mask_dropdown ),"Normal" )]
239
236
return select_frame , operation_log
240
237
@@ -261,7 +258,7 @@ def vos_tracking_video(inference_state, video_state, interactive_state, mask_dro
261
258
masks = np .array (masks )
262
259
263
260
painted_images = None
264
- if interactive_state ["track_end_number" ]:
261
+ if interactive_state ["track_end_number" ]:
265
262
video_state ["masks" ][video_state ["select_frame_number" ]:interactive_state ["track_end_number" ]] = masks
266
263
org_images = video_state ["origin_images" ][video_state ["select_frame_number" ]:interactive_state ["track_end_number" ]]
267
264
color = 255 * np .ones ((1 , org_images .shape [- 3 ], org_images .shape [- 2 ], 3 )) * np .array ([[[[0 ,1 ,1 ]]]])
@@ -276,10 +273,10 @@ def vos_tracking_video(inference_state, video_state, interactive_state, mask_dro
276
273
277
274
video_output = generate_video_from_frames (video_state ["painted_images" ], output_path = "./result/track/{}" .format (video_state ["video_name" ]), fps = video_state ["fps" ]) # import video_input to name the output video
278
275
interactive_state ["inference_times" ] += 1
279
-
276
+
280
277
return inference_state , video_output , video_state , interactive_state , operation_log
281
278
282
- # inpaint
279
+ # inpaint
283
280
def inpaint_video (video_state , text_pos_input , text_neg_input , interactive_state , mask_dropdown ):
284
281
operation_log = [("" ,"" ), ("Removed the selected masks." ,"Normal" )]
285
282
@@ -337,7 +334,7 @@ def echo_text(text1, text2):
337
334
338
335
with gr .Blocks () as iface :
339
336
"""
340
- state for
337
+ state for
341
338
"""
342
339
click_state = gr .State ([[],[]])
343
340
interactive_state = gr .State ({
@@ -376,20 +373,20 @@ def echo_text(text1, text2):
376
373
# for user video input
377
374
with gr .Column ():
378
375
with gr .Row ():
379
- video_input = gr .Video (autosize = True )
376
+ video_input = gr .Video () # autosize=True)
380
377
with gr .Column ():
381
378
video_info = gr .Textbox (label = "Video Info" )
382
379
resize_info = gr .Textbox (value = "If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
383
380
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing." , label = "Tips for running this demo." )
384
381
resize_ratio_slider = gr .Slider (minimum = 0.02 , maximum = 1 , step = 0.02 , value = 1 , label = "Resize ratio" , visible = True )
385
-
382
+
386
383
387
384
with gr .Row ():
388
385
# put the template frame under the radio button
389
386
with gr .Column ():
390
387
# extract frames
391
388
with gr .Column ():
392
- extract_frames_button = gr .Button (value = "Get video info" , interactive = True , variant = "primary" )
389
+ extract_frames_button = gr .Button (value = "Get video info" , interactive = True , variant = "primary" )
393
390
394
391
# click points settins, negative or positive, mode continuous or single
395
392
with gr .Row ():
@@ -400,46 +397,46 @@ def echo_text(text1, text2):
400
397
label = "Point prompt" ,
401
398
interactive = True ,
402
399
visible = False )
403
- clear_button_click = gr .Button (value = "Clear clicks" , interactive = True , visible = False ).style (height = 160 )
404
- template_frame = gr .Image (type = "pil" ,interactive = True , elem_id = "template_frame" , visible = False ).style (height = 360 )
400
+ clear_button_click = gr .Button (value = "Clear clicks" , interactive = True , visible = False )# .style(height=160)
401
+ template_frame = gr .Image (type = "pil" ,interactive = True , elem_id = "template_frame" , visible = False )# .style(height=360)
405
402
with gr .Row ():
406
403
image_selection_slider = gr .Slider (minimum = 1 , maximum = 100 , step = 1 , value = 1 , label = "Track start frame" , visible = False )
407
404
track_pause_number_slider = gr .Slider (minimum = 1 , maximum = 100 , step = 1 , value = 1 , label = "Track end frame" , visible = False )
408
405
text_pos_input = gr .Textbox (label = "Positive Prompt" , placeholder = "positive prompt..." , interactive = True , visible = False )
409
406
text_neg_input = gr .Textbox (label = "Negative Prompt" , placeholder = "negative prompt..." , interactive = True , visible = False )
410
-
407
+
411
408
with gr .Column ():
412
409
run_status = gr .HighlightedText (value = [("Text" ,"Error" ),("to be" ,"Label 2" ),("highlighted" ,"Label 3" )], visible = False )
413
410
mask_dropdown = gr .Dropdown (multiselect = True , value = [], label = "Mask selection" , info = "." , visible = False )
414
- video_output = gr .Video (autosize = True , visible = False ).style (height = 360 )
411
+ video_output = gr .Video (visible = False ) #gr.Video( autosize=True, visible=False)# .style(height=360)
415
412
with gr .Row ():
416
413
tracking_video_predict_button = gr .Button (value = "Tracking" , visible = False )
417
414
inpaint_video_predict_button = gr .Button (value = "Inpainting" , visible = False )
418
415
419
- # first step: get the video information
416
+ # first step: get the video information
420
417
extract_frames_button .click (
421
418
fn = get_frames_from_video ,
422
419
inputs = [
423
420
video_input , video_state
424
421
],
425
422
outputs = [text_pos_input , text_neg_input , inference_state , video_state , video_info , template_frame ,
426
- image_selection_slider , track_pause_number_slider ,
427
- point_prompt , clear_button_click , template_frame , tracking_video_predict_button ,
428
- video_output , mask_dropdown , inpaint_video_predict_button ,
423
+ image_selection_slider , track_pause_number_slider ,
424
+ point_prompt , clear_button_click , template_frame , tracking_video_predict_button ,
425
+ video_output , mask_dropdown , inpaint_video_predict_button ,
429
426
run_status ]
430
- )
427
+ )
431
428
432
429
# second step: select images from slider
433
- image_selection_slider .release (fn = select_template ,
434
- inputs = [image_selection_slider , video_state , interactive_state ],
430
+ image_selection_slider .release (fn = select_template ,
431
+ inputs = [image_selection_slider , video_state , interactive_state ],
435
432
outputs = [template_frame , video_state , interactive_state , run_status ], api_name = "select_image" )
436
- track_pause_number_slider .release (fn = get_end_number ,
437
- inputs = [track_pause_number_slider , video_state , interactive_state ],
433
+ track_pause_number_slider .release (fn = get_end_number ,
434
+ inputs = [track_pause_number_slider , video_state , interactive_state ],
438
435
outputs = [template_frame , interactive_state , run_status ], api_name = "end_image" )
439
- resize_ratio_slider .release (fn = get_resize_ratio ,
440
- inputs = [resize_ratio_slider , interactive_state ],
436
+ resize_ratio_slider .release (fn = get_resize_ratio ,
437
+ inputs = [resize_ratio_slider , interactive_state ],
441
438
outputs = [interactive_state ], api_name = "resize_ratio" )
442
-
439
+
443
440
# click select image to get mask using sam
444
441
template_frame .select (
445
442
fn = sam_refine ,
@@ -467,11 +464,11 @@ def echo_text(text1, text2):
467
464
inputs = [video_state , interactive_state , mask_dropdown ],
468
465
outputs = [template_frame , run_status ]
469
466
)
470
-
467
+
471
468
# clear input
472
469
video_input .clear (
473
470
lambda : (
474
- gr .update (visible = False ),
471
+ gr .update (visible = False ),
475
472
gr .update (visible = False ),
476
473
init_state (),
477
474
{
@@ -503,10 +500,10 @@ def echo_text(text1, text2):
503
500
None ,
504
501
gr .update (visible = False ), gr .update (visible = False ), gr .update (visible = False ), gr .update (visible = False ), \
505
502
gr .update (visible = False ), gr .update (visible = False ), gr .update (visible = False , value = []), gr .update (visible = False ), \
506
- gr .update (visible = False ), gr .update (visible = False ), gr .update (visible = False )
503
+ gr .update (visible = False ), gr .update (visible = False ), gr .update (visible = False )
507
504
),
508
505
[],
509
- [
506
+ [
510
507
text_pos_input ,
511
508
text_neg_input ,
512
509
inference_state ,
@@ -515,8 +512,8 @@ def echo_text(text1, text2):
515
512
click_state ,
516
513
video_output ,
517
514
template_frame ,
518
- tracking_video_predict_button , image_selection_slider , track_pause_number_slider , point_prompt ,
519
- clear_button_click , template_frame , tracking_video_predict_button , video_output ,
515
+ tracking_video_predict_button , image_selection_slider , track_pause_number_slider , point_prompt ,
516
+ clear_button_click , template_frame , tracking_video_predict_button , video_output ,
520
517
mask_dropdown , inpaint_video_predict_button , run_status
521
518
],
522
519
queue = False ,
@@ -528,6 +525,6 @@ def echo_text(text1, text2):
528
525
inputs = [inference_state , video_state , click_state ],
529
526
outputs = [inference_state , template_frame , click_state , run_status ],
530
527
)
531
- iface .queue (concurrency_count = 1 )
532
- iface .launch (debug = True , enable_queue = True , server_port = 10087 , server_name = "0.0.0.0" )
528
+ # iface.queue()# concurrency_count=1)
529
+ iface .launch (server_port = 8000 , server_name = "0.0.0.0" )
533
530
# iface.launch(debug=True, enable_queue=True)
0 commit comments