Skip to content

Commit 7811bf1

Browse files
authored
Update app.py
1 parent 0b95ece commit 7811bf1

File tree

1 file changed

+44
-47
lines changed

1 file changed

+44
-47
lines changed

app.py

Lines changed: 44 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -20,27 +20,24 @@
2020

2121
from sam2.build_sam import build_sam2_video_predictor
2222

23+
parser = argparse.ArgumentParser()
24+
parser.add_argument("--config", type=str, required=True)
25+
parser.add_argument("--model_path", type=str, default="../")
26+
parser.add_argument("--pretrain_model_path", type=str, default="../")
27+
parser.add_argument("--sub_folder", type=str, default="unet")
28+
args = parser.parse_args()
29+
2330
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
2431
model_cfg = "sam2_hiera_l.yaml"
2532

2633
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
2734

28-
config = OmegaConf.load("./configs/code_release.yaml")
29-
validation_pipeline = load_model(model_path="../", \
30-
sub_folder="unet2", \
31-
pretrained_model_path="/data/home/zibojia/StableDiffusion/", \
35+
config = OmegaConf.load(args.config)
36+
validation_pipeline = load_model(model_path=args.model_path, \
37+
sub_folder=args.sub_folder, \
38+
pretrained_model_path=args.pretrain_model_path, \
3239
**config)
3340

34-
video_path = 'images/images.npy'
35-
mask_path = 'images/masks.npy'
36-
validation_images2 = 2*(np.load(video_path)/255.0 - 0.5)
37-
validation_masks2 = np.load(mask_path)/255.0
38-
39-
print('validation_images2', validation_images2.shape)
40-
print('validation_masks2', validation_masks2.shape)
41-
42-
43-
4441
def init_state(
4542
offload_video_to_cpu=False,
4643
offload_state_to_cpu=False
@@ -108,7 +105,7 @@ def get_frames_from_video(video_input, video_state):
108105
frames = vr.get_batch(list(range(len(vr)))).asnumpy()
109106
inference_state = predictor.init_state(images=frames)
110107
fps = 30
111-
image_size = (frames[0].shape[0],frames[0].shape[1])
108+
image_size = (frames[0].shape[0],frames[0].shape[1])
112109
# initialize video_state
113110
video_state = {
114111
"user_name": user_name,
@@ -234,7 +231,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
234231
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
235232
mask = interactive_state["multi_mask"]["masks"][mask_number]
236233
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
237-
234+
238235
operation_log = [("",""), ("Select {} for tracking or inpainting".format(mask_dropdown),"Normal")]
239236
return select_frame, operation_log
240237

@@ -261,7 +258,7 @@ def vos_tracking_video(inference_state, video_state, interactive_state, mask_dro
261258
masks = np.array(masks)
262259

263260
painted_images = None
264-
if interactive_state["track_end_number"]:
261+
if interactive_state["track_end_number"]:
265262
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
266263
org_images = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
267264
color = 255*np.ones((1, org_images.shape[-3], org_images.shape[-2], 3)) * np.array([[[[0,1,1]]]])
@@ -276,10 +273,10 @@ def vos_tracking_video(inference_state, video_state, interactive_state, mask_dro
276273

277274
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=video_state["fps"]) # import video_input to name the output video
278275
interactive_state["inference_times"] += 1
279-
276+
280277
return inference_state, video_output, video_state, interactive_state, operation_log
281278

282-
# inpaint
279+
# inpaint
283280
def inpaint_video(video_state, text_pos_input, text_neg_input, interactive_state, mask_dropdown):
284281
operation_log = [("",""), ("Removed the selected masks.","Normal")]
285282

@@ -337,7 +334,7 @@ def echo_text(text1, text2):
337334

338335
with gr.Blocks() as iface:
339336
"""
340-
state for
337+
state for
341338
"""
342339
click_state = gr.State([[],[]])
343340
interactive_state = gr.State({
@@ -376,20 +373,20 @@ def echo_text(text1, text2):
376373
# for user video input
377374
with gr.Column():
378375
with gr.Row():
379-
video_input = gr.Video(autosize=True)
376+
video_input = gr.Video()#autosize=True)
380377
with gr.Column():
381378
video_info = gr.Textbox(label="Video Info")
382379
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
383380
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
384381
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
385-
382+
386383

387384
with gr.Row():
388385
# put the template frame under the radio button
389386
with gr.Column():
390387
# extract frames
391388
with gr.Column():
392-
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
389+
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
393390

394391
# click point settings: negative or positive; mode: continuous or single
395392
with gr.Row():
@@ -400,46 +397,46 @@ def echo_text(text1, text2):
400397
label="Point prompt",
401398
interactive=True,
402399
visible=False)
403-
clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False).style(height=160)
404-
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
400+
clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False)#.style(height=160)
401+
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False)#.style(height=360)
405402
with gr.Row():
406403
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
407404
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
408405
text_pos_input = gr.Textbox(label="Positive Prompt", placeholder="positive prompt...", interactive=True, visible=False)
409406
text_neg_input = gr.Textbox(label="Negative Prompt", placeholder="negative prompt...", interactive=True, visible=False)
410-
407+
411408
with gr.Column():
412409
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=False)
413410
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
414-
video_output = gr.Video(autosize=True, visible=False).style(height=360)
411+
video_output = gr.Video(visible=False)#gr.Video(autosize=True, visible=False)#.style(height=360)
415412
with gr.Row():
416413
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
417414
inpaint_video_predict_button = gr.Button(value="Inpainting", visible=False)
418415

419-
# first step: get the video information
416+
# first step: get the video information
420417
extract_frames_button.click(
421418
fn=get_frames_from_video,
422419
inputs=[
423420
video_input, video_state
424421
],
425422
outputs=[text_pos_input, text_neg_input, inference_state, video_state, video_info, template_frame,
426-
image_selection_slider, track_pause_number_slider,
427-
point_prompt, clear_button_click, template_frame, tracking_video_predict_button,
428-
video_output, mask_dropdown, inpaint_video_predict_button,
423+
image_selection_slider, track_pause_number_slider,
424+
point_prompt, clear_button_click, template_frame, tracking_video_predict_button,
425+
video_output, mask_dropdown, inpaint_video_predict_button,
429426
run_status]
430-
)
427+
)
431428

432429
# second step: select images from slider
433-
image_selection_slider.release(fn=select_template,
434-
inputs=[image_selection_slider, video_state, interactive_state],
430+
image_selection_slider.release(fn=select_template,
431+
inputs=[image_selection_slider, video_state, interactive_state],
435432
outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
436-
track_pause_number_slider.release(fn=get_end_number,
437-
inputs=[track_pause_number_slider, video_state, interactive_state],
433+
track_pause_number_slider.release(fn=get_end_number,
434+
inputs=[track_pause_number_slider, video_state, interactive_state],
438435
outputs=[template_frame, interactive_state, run_status], api_name="end_image")
439-
resize_ratio_slider.release(fn=get_resize_ratio,
440-
inputs=[resize_ratio_slider, interactive_state],
436+
resize_ratio_slider.release(fn=get_resize_ratio,
437+
inputs=[resize_ratio_slider, interactive_state],
441438
outputs=[interactive_state], api_name="resize_ratio")
442-
439+
443440
# click select image to get mask using sam
444441
template_frame.select(
445442
fn=sam_refine,
@@ -467,11 +464,11 @@ def echo_text(text1, text2):
467464
inputs=[video_state, interactive_state, mask_dropdown],
468465
outputs=[template_frame, run_status]
469466
)
470-
467+
471468
# clear input
472469
video_input.clear(
473470
lambda: (
474-
gr.update(visible=False),
471+
gr.update(visible=False),
475472
gr.update(visible=False),
476473
init_state(),
477474
{
@@ -503,10 +500,10 @@ def echo_text(text1, text2):
503500
None,
504501
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
505502
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
506-
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
503+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
507504
),
508505
[],
509-
[
506+
[
510507
text_pos_input,
511508
text_neg_input,
512509
inference_state,
@@ -515,8 +512,8 @@ def echo_text(text1, text2):
515512
click_state,
516513
video_output,
517514
template_frame,
518-
tracking_video_predict_button, image_selection_slider , track_pause_number_slider, point_prompt,
519-
clear_button_click, template_frame, tracking_video_predict_button, video_output,
515+
tracking_video_predict_button, image_selection_slider , track_pause_number_slider, point_prompt,
516+
clear_button_click, template_frame, tracking_video_predict_button, video_output,
520517
mask_dropdown, inpaint_video_predict_button, run_status
521518
],
522519
queue=False,
@@ -528,6 +525,6 @@ def echo_text(text1, text2):
528525
inputs = [inference_state, video_state, click_state],
529526
outputs = [inference_state, template_frame, click_state, run_status],
530527
)
531-
iface.queue(concurrency_count=1)
532-
iface.launch(debug=True, enable_queue=True, server_port=10087, server_name="0.0.0.0")
528+
#iface.queue()#concurrency_count=1)
529+
iface.launch(server_port=8000, server_name="0.0.0.0")
533530
# iface.launch(debug=True, enable_queue=True)

0 commit comments

Comments
 (0)