10
10
import subprocess
11
11
import pathlib
12
12
import shutil
13
- from library .dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
13
+ from library .dreambooth_folder_creation_gui import (
14
+ gradio_dreambooth_folder_creation_tab ,
15
+ )
14
16
from library .basic_caption_gui import gradio_basic_caption_gui_tab
15
17
from library .convert_model_gui import gradio_convert_model_tab
16
18
from library .blip_caption_gui import gradio_blip_caption_gui_tab
20
22
get_folder_path ,
21
23
remove_doublequote ,
22
24
get_file_path ,
23
- get_saveasfile_path
25
+ get_saveasfile_path ,
24
26
)
25
27
from easygui import msgbox
26
28
27
29
folder_symbol = '\U0001f4c2 ' # 📂
28
30
refresh_symbol = '\U0001f504 ' # 🔄
29
31
save_style_symbol = '\U0001f4be ' # 💾
30
- document_symbol = '\U0001F4C4 ' # 📄
32
+ document_symbol = '\U0001F4C4 ' # 📄
31
33
32
34
33
35
def save_configuration (
@@ -60,30 +62,26 @@ def save_configuration(
60
62
stop_text_encoder_training ,
61
63
use_8bit_adam ,
62
64
xformers ,
63
- save_model_as
65
+ save_model_as ,
66
+ shuffle_caption ,
67
+ save_state ,
68
+ resume ,
69
+ prior_loss_weight ,
64
70
):
65
71
original_file_path = file_path
66
72
67
73
save_as_bool = True if save_as .get ('label' ) == 'True' else False
68
74
69
75
if save_as_bool :
70
76
print ('Save as...' )
71
- # file_path = filesavebox(
72
- # 'Select the config file to save',
73
- # default='finetune.json',
74
- # filetypes='*.json',
75
- # )
76
77
file_path = get_saveasfile_path (file_path )
77
78
else :
78
79
print ('Save...' )
79
80
if file_path == None or file_path == '' :
80
- # file_path = filesavebox(
81
- # 'Select the config file to save',
82
- # default='finetune.json',
83
- # filetypes='*.json',
84
- # )
85
81
file_path = get_saveasfile_path (file_path )
86
82
83
+ # print(file_path)
84
+
87
85
if file_path == None or file_path == '' :
88
86
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
89
87
@@ -116,7 +114,11 @@ def save_configuration(
116
114
'stop_text_encoder_training' : stop_text_encoder_training ,
117
115
'use_8bit_adam' : use_8bit_adam ,
118
116
'xformers' : xformers ,
119
- 'save_model_as' : save_model_as
117
+ 'save_model_as' : save_model_as ,
118
+ 'shuffle_caption' : shuffle_caption ,
119
+ 'save_state' : save_state ,
120
+ 'resume' : resume ,
121
+ 'prior_loss_weight' : prior_loss_weight ,
120
122
}
121
123
122
124
# Save the data to the selected file
@@ -155,14 +157,18 @@ def open_configuration(
155
157
stop_text_encoder_training ,
156
158
use_8bit_adam ,
157
159
xformers ,
158
- save_model_as
160
+ save_model_as ,
161
+ shuffle_caption ,
162
+ save_state ,
163
+ resume ,
164
+ prior_loss_weight ,
159
165
):
160
166
161
167
original_file_path = file_path
162
168
file_path = get_file_path (file_path )
169
+ # print(file_path)
163
170
164
- if file_path != '' and file_path != None :
165
- print (file_path )
171
+ if not file_path == '' and not file_path == None :
166
172
# load variables from JSON file
167
173
with open (file_path , 'r' ) as f :
168
174
my_data = json .load (f )
@@ -204,7 +210,11 @@ def open_configuration(
204
210
my_data .get ('stop_text_encoder_training' , stop_text_encoder_training ),
205
211
my_data .get ('use_8bit_adam' , use_8bit_adam ),
206
212
my_data .get ('xformers' , xformers ),
207
- my_data .get ('save_model_as' , save_model_as )
213
+ my_data .get ('save_model_as' , save_model_as ),
214
+ my_data .get ('shuffle_caption' , shuffle_caption ),
215
+ my_data .get ('save_state' , save_state ),
216
+ my_data .get ('resume' , resume ),
217
+ my_data .get ('prior_loss_weight' , prior_loss_weight ),
208
218
)
209
219
210
220
@@ -236,7 +246,11 @@ def train_model(
236
246
stop_text_encoder_training_pct ,
237
247
use_8bit_adam ,
238
248
xformers ,
239
- save_model_as
249
+ save_model_as ,
250
+ shuffle_caption ,
251
+ save_state ,
252
+ resume ,
253
+ prior_loss_weight ,
240
254
):
241
255
def save_inference_file (output_dir , v2 , v_parameterization ):
242
256
# Copy inference model for v2 if required
@@ -360,6 +374,10 @@ def save_inference_file(output_dir, v2, v_parameterization):
360
374
run_cmd += ' --use_8bit_adam'
361
375
if xformers :
362
376
run_cmd += ' --xformers'
377
+ if shuffle_caption :
378
+ run_cmd += ' --shuffle_caption'
379
+ if save_state :
380
+ run_cmd += ' --save_state'
363
381
run_cmd += (
364
382
f' --pretrained_model_name_or_path={ pretrained_model_name_or_path } '
365
383
)
@@ -382,17 +400,23 @@ def save_inference_file(output_dir, v2, v_parameterization):
382
400
run_cmd += f' --logging_dir={ logging_dir } '
383
401
run_cmd += f' --caption_extention={ caption_extention } '
384
402
if not stop_text_encoder_training == 0 :
385
- run_cmd += f' --stop_text_encoder_training={ stop_text_encoder_training } '
403
+ run_cmd += (
404
+ f' --stop_text_encoder_training={ stop_text_encoder_training } '
405
+ )
386
406
if not save_model_as == 'same as source model' :
387
407
run_cmd += f' --save_model_as={ save_model_as } '
408
+ if not resume == '' :
409
+ run_cmd += f' --resume={ resume } '
410
+ if not float (prior_loss_weight ) == 1.0 :
411
+ run_cmd += f' --prior_loss_weight={ prior_loss_weight } '
388
412
389
413
print (run_cmd )
390
414
# Run the command
391
415
subprocess .run (run_cmd )
392
416
393
417
# check if output_dir/last is a folder... therefore it is a diffuser model
394
418
last_dir = pathlib .Path (f'{ output_dir } /last' )
395
-
419
+
396
420
if not last_dir .is_dir ():
397
421
# Copy inference model for v2 if required
398
422
save_inference_file (output_dir , v2 , v_parameterization )
@@ -472,8 +496,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
472
496
)
473
497
config_file_name = gr .Textbox (
474
498
label = '' ,
475
- # placeholder="type the configuration file path or use the 'Open' button above to select it...",
476
- interactive = False
499
+ placeholder = "type the configuration file path or use the 'Open' button above to select it..." ,
500
+ interactive = True ,
477
501
)
478
502
# config_file_name.change(
479
503
# remove_doublequote,
@@ -491,13 +515,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
491
515
document_symbol , elem_id = 'open_folder_small'
492
516
)
493
517
pretrained_model_name_or_path_fille .click (
494
- get_file_path , inputs = [pretrained_model_name_or_path_input ], outputs = pretrained_model_name_or_path_input
518
+ get_file_path ,
519
+ inputs = [pretrained_model_name_or_path_input ],
520
+ outputs = pretrained_model_name_or_path_input ,
495
521
)
496
522
pretrained_model_name_or_path_folder = gr .Button (
497
523
folder_symbol , elem_id = 'open_folder_small'
498
524
)
499
525
pretrained_model_name_or_path_folder .click (
500
- get_folder_path , outputs = pretrained_model_name_or_path_input
526
+ get_folder_path ,
527
+ outputs = pretrained_model_name_or_path_input ,
501
528
)
502
529
model_list = gr .Dropdown (
503
530
label = '(Optional) Model Quick Pick' ,
@@ -517,10 +544,10 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
517
544
'same as source model' ,
518
545
'ckpt' ,
519
546
'diffusers' ,
520
- " diffusers_safetensors" ,
547
+ ' diffusers_safetensors' ,
521
548
'safetensors' ,
522
549
],
523
- value = 'same as source model'
550
+ value = 'same as source model' ,
524
551
)
525
552
with gr .Row ():
526
553
v2_input = gr .Checkbox (label = 'v2' , value = True )
@@ -607,7 +634,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
607
634
)
608
635
with gr .Tab ('Training parameters' ):
609
636
with gr .Row ():
610
- learning_rate_input = gr .Textbox (label = 'Learning rate' , value = 1e-6 )
637
+ learning_rate_input = gr .Textbox (
638
+ label = 'Learning rate' , value = 1e-6
639
+ )
611
640
lr_scheduler_input = gr .Dropdown (
612
641
label = 'LR Scheduler' ,
613
642
choices = [
@@ -662,7 +691,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
662
691
with gr .Row ():
663
692
seed_input = gr .Textbox (label = 'Seed' , value = 1234 )
664
693
max_resolution_input = gr .Textbox (
665
- label = 'Max resolution' , value = '512,512' , placeholder = '512,512'
694
+ label = 'Max resolution' ,
695
+ value = '512,512' ,
696
+ placeholder = '512,512' ,
666
697
)
667
698
with gr .Row ():
668
699
caption_extention_input = gr .Textbox (
@@ -676,27 +707,45 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
676
707
step = 1 ,
677
708
label = 'Stop text encoder training' ,
678
709
)
679
- with gr .Row ():
680
- full_fp16_input = gr .Checkbox (
681
- label = 'Full fp16 training (experimental)' , value = False
682
- )
683
- no_token_padding_input = gr .Checkbox (
684
- label = 'No token padding' , value = False
685
- )
686
-
687
- gradient_checkpointing_input = gr .Checkbox (
688
- label = 'Gradient checkpointing' , value = False
689
- )
690
710
with gr .Row ():
691
711
enable_bucket_input = gr .Checkbox (
692
712
label = 'Enable buckets' , value = True
693
713
)
694
- cache_latent_input = gr .Checkbox (label = 'Cache latent' , value = True )
714
+ cache_latent_input = gr .Checkbox (
715
+ label = 'Cache latent' , value = True
716
+ )
695
717
use_8bit_adam_input = gr .Checkbox (
696
718
label = 'Use 8bit adam' , value = True
697
719
)
698
720
xformers_input = gr .Checkbox (label = 'Use xformers' , value = True )
699
-
721
+ with gr .Accordion ('Advanced Configuration' , open = False ):
722
+ with gr .Row ():
723
+ full_fp16_input = gr .Checkbox (
724
+ label = 'Full fp16 training (experimental)' , value = False
725
+ )
726
+ no_token_padding_input = gr .Checkbox (
727
+ label = 'No token padding' , value = False
728
+ )
729
+
730
+ gradient_checkpointing_input = gr .Checkbox (
731
+ label = 'Gradient checkpointing' , value = False
732
+ )
733
+
734
+ shuffle_caption = gr .Checkbox (
735
+ label = 'Shuffle caption' , value = False
736
+ )
737
+ save_state = gr .Checkbox (label = 'Save state' , value = False )
738
+ with gr .Row ():
739
+ resume = gr .Textbox (
740
+ label = 'Resume' ,
741
+ placeholder = 'path to "last-state" state folder to resume from' ,
742
+ )
743
+ resume_button = gr .Button ('📂' , elem_id = 'open_folder_small' )
744
+ resume_button .click (get_folder_path , outputs = resume )
745
+ prior_loss_weight = gr .Number (
746
+ label = 'Prior loss weight' , value = 1.0
747
+ )
748
+
700
749
button_run = gr .Button ('Train model' )
701
750
702
751
with gr .Tab ('Utilities' ):
@@ -713,8 +762,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
713
762
gradio_dataset_balancing_tab ()
714
763
gradio_convert_model_tab ()
715
764
716
-
717
-
718
765
button_open_config .click (
719
766
open_configuration ,
720
767
inputs = [
@@ -746,7 +793,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
746
793
stop_text_encoder_training_input ,
747
794
use_8bit_adam_input ,
748
795
xformers_input ,
749
- save_model_as_dropdown
796
+ save_model_as_dropdown ,
797
+ shuffle_caption ,
798
+ save_state ,
799
+ resume ,
800
+ prior_loss_weight ,
750
801
],
751
802
outputs = [
752
803
config_file_name ,
@@ -777,7 +828,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
777
828
stop_text_encoder_training_input ,
778
829
use_8bit_adam_input ,
779
830
xformers_input ,
780
- save_model_as_dropdown
831
+ save_model_as_dropdown ,
832
+ shuffle_caption ,
833
+ save_state ,
834
+ resume ,
835
+ prior_loss_weight ,
781
836
],
782
837
)
783
838
@@ -815,7 +870,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
815
870
stop_text_encoder_training_input ,
816
871
use_8bit_adam_input ,
817
872
xformers_input ,
818
- save_model_as_dropdown
873
+ save_model_as_dropdown ,
874
+ shuffle_caption ,
875
+ save_state ,
876
+ resume ,
877
+ prior_loss_weight ,
819
878
],
820
879
outputs = [config_file_name ],
821
880
)
@@ -852,7 +911,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
852
911
stop_text_encoder_training_input ,
853
912
use_8bit_adam_input ,
854
913
xformers_input ,
855
- save_model_as_dropdown
914
+ save_model_as_dropdown ,
915
+ shuffle_caption ,
916
+ save_state ,
917
+ resume ,
918
+ prior_loss_weight ,
856
919
],
857
920
outputs = [config_file_name ],
858
921
)
@@ -887,7 +950,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
887
950
stop_text_encoder_training_input ,
888
951
use_8bit_adam_input ,
889
952
xformers_input ,
890
- save_model_as_dropdown
953
+ save_model_as_dropdown ,
954
+ shuffle_caption ,
955
+ save_state ,
956
+ resume ,
957
+ prior_loss_weight ,
891
958
],
892
959
)
893
960
0 commit comments