Skip to content

Commit b78df38

Browse files
authored
Merge pull request #9 from bmaltais/dev
v18.4
2 parents 6987f51 + 69558b5 commit b78df38

File tree

8 files changed

+1709
-1135
lines changed

8 files changed

+1709
-1135
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ Drop by the discord server for support: https://discord.com/channels/10415185624
130130

131131
## Change history
132132

133+
* 12/19 (v18.4) update:
134+
- Add support for shuffle_caption, save_state, resume, prior_loss_weight under "Advanced Configuration" section
135+
- Fix issue with open/save config not working properly
133136
* 12/19 (v18.3) update:
134137
- fix stop encoder training issue
135138
* 12/19 (v18.2) update:

dreambooth_gui.py

Lines changed: 117 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
import subprocess
1111
import pathlib
1212
import shutil
13-
from library.dreambooth_folder_creation_gui import gradio_dreambooth_folder_creation_tab
13+
from library.dreambooth_folder_creation_gui import (
14+
gradio_dreambooth_folder_creation_tab,
15+
)
1416
from library.basic_caption_gui import gradio_basic_caption_gui_tab
1517
from library.convert_model_gui import gradio_convert_model_tab
1618
from library.blip_caption_gui import gradio_blip_caption_gui_tab
@@ -20,14 +22,14 @@
2022
get_folder_path,
2123
remove_doublequote,
2224
get_file_path,
23-
get_saveasfile_path
25+
get_saveasfile_path,
2426
)
2527
from easygui import msgbox
2628

2729
folder_symbol = '\U0001f4c2' # 📂
2830
refresh_symbol = '\U0001f504' # 🔄
2931
save_style_symbol = '\U0001f4be' # 💾
30-
document_symbol = '\U0001F4C4' # 📄
32+
document_symbol = '\U0001F4C4' # 📄
3133

3234

3335
def save_configuration(
@@ -60,30 +62,26 @@ def save_configuration(
6062
stop_text_encoder_training,
6163
use_8bit_adam,
6264
xformers,
63-
save_model_as
65+
save_model_as,
66+
shuffle_caption,
67+
save_state,
68+
resume,
69+
prior_loss_weight,
6470
):
6571
original_file_path = file_path
6672

6773
save_as_bool = True if save_as.get('label') == 'True' else False
6874

6975
if save_as_bool:
7076
print('Save as...')
71-
# file_path = filesavebox(
72-
# 'Select the config file to save',
73-
# default='finetune.json',
74-
# filetypes='*.json',
75-
# )
7677
file_path = get_saveasfile_path(file_path)
7778
else:
7879
print('Save...')
7980
if file_path == None or file_path == '':
80-
# file_path = filesavebox(
81-
# 'Select the config file to save',
82-
# default='finetune.json',
83-
# filetypes='*.json',
84-
# )
8581
file_path = get_saveasfile_path(file_path)
8682

83+
# print(file_path)
84+
8785
if file_path == None or file_path == '':
8886
return original_file_path # In case a file_path was provided and the user decide to cancel the open action
8987

@@ -116,7 +114,11 @@ def save_configuration(
116114
'stop_text_encoder_training': stop_text_encoder_training,
117115
'use_8bit_adam': use_8bit_adam,
118116
'xformers': xformers,
119-
'save_model_as': save_model_as
117+
'save_model_as': save_model_as,
118+
'shuffle_caption': shuffle_caption,
119+
'save_state': save_state,
120+
'resume': resume,
121+
'prior_loss_weight': prior_loss_weight,
120122
}
121123

122124
# Save the data to the selected file
@@ -155,14 +157,18 @@ def open_configuration(
155157
stop_text_encoder_training,
156158
use_8bit_adam,
157159
xformers,
158-
save_model_as
160+
save_model_as,
161+
shuffle_caption,
162+
save_state,
163+
resume,
164+
prior_loss_weight,
159165
):
160166

161167
original_file_path = file_path
162168
file_path = get_file_path(file_path)
169+
# print(file_path)
163170

164-
if file_path != '' and file_path != None:
165-
print(file_path)
171+
if not file_path == '' and not file_path == None:
166172
# load variables from JSON file
167173
with open(file_path, 'r') as f:
168174
my_data = json.load(f)
@@ -204,7 +210,11 @@ def open_configuration(
204210
my_data.get('stop_text_encoder_training', stop_text_encoder_training),
205211
my_data.get('use_8bit_adam', use_8bit_adam),
206212
my_data.get('xformers', xformers),
207-
my_data.get('save_model_as', save_model_as)
213+
my_data.get('save_model_as', save_model_as),
214+
my_data.get('shuffle_caption', shuffle_caption),
215+
my_data.get('save_state', save_state),
216+
my_data.get('resume', resume),
217+
my_data.get('prior_loss_weight', prior_loss_weight),
208218
)
209219

210220

@@ -236,7 +246,11 @@ def train_model(
236246
stop_text_encoder_training_pct,
237247
use_8bit_adam,
238248
xformers,
239-
save_model_as
249+
save_model_as,
250+
shuffle_caption,
251+
save_state,
252+
resume,
253+
prior_loss_weight,
240254
):
241255
def save_inference_file(output_dir, v2, v_parameterization):
242256
# Copy inference model for v2 if required
@@ -360,6 +374,10 @@ def save_inference_file(output_dir, v2, v_parameterization):
360374
run_cmd += ' --use_8bit_adam'
361375
if xformers:
362376
run_cmd += ' --xformers'
377+
if shuffle_caption:
378+
run_cmd += ' --shuffle_caption'
379+
if save_state:
380+
run_cmd += ' --save_state'
363381
run_cmd += (
364382
f' --pretrained_model_name_or_path={pretrained_model_name_or_path}'
365383
)
@@ -382,17 +400,23 @@ def save_inference_file(output_dir, v2, v_parameterization):
382400
run_cmd += f' --logging_dir={logging_dir}'
383401
run_cmd += f' --caption_extention={caption_extention}'
384402
if not stop_text_encoder_training == 0:
385-
run_cmd += f' --stop_text_encoder_training={stop_text_encoder_training}'
403+
run_cmd += (
404+
f' --stop_text_encoder_training={stop_text_encoder_training}'
405+
)
386406
if not save_model_as == 'same as source model':
387407
run_cmd += f' --save_model_as={save_model_as}'
408+
if not resume == '':
409+
run_cmd += f' --resume={resume}'
410+
if not float(prior_loss_weight) == 1.0:
411+
run_cmd += f' --prior_loss_weight={prior_loss_weight}'
388412

389413
print(run_cmd)
390414
# Run the command
391415
subprocess.run(run_cmd)
392416

393417
# check if output_dir/last is a folder... therefore it is a diffuser model
394418
last_dir = pathlib.Path(f'{output_dir}/last')
395-
419+
396420
if not last_dir.is_dir():
397421
# Copy inference model for v2 if required
398422
save_inference_file(output_dir, v2, v_parameterization)
@@ -472,8 +496,8 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
472496
)
473497
config_file_name = gr.Textbox(
474498
label='',
475-
# placeholder="type the configuration file path or use the 'Open' button above to select it...",
476-
interactive=False
499+
placeholder="type the configuration file path or use the 'Open' button above to select it...",
500+
interactive=True,
477501
)
478502
# config_file_name.change(
479503
# remove_doublequote,
@@ -491,13 +515,16 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
491515
document_symbol, elem_id='open_folder_small'
492516
)
493517
pretrained_model_name_or_path_fille.click(
494-
get_file_path, inputs=[pretrained_model_name_or_path_input], outputs=pretrained_model_name_or_path_input
518+
get_file_path,
519+
inputs=[pretrained_model_name_or_path_input],
520+
outputs=pretrained_model_name_or_path_input,
495521
)
496522
pretrained_model_name_or_path_folder = gr.Button(
497523
folder_symbol, elem_id='open_folder_small'
498524
)
499525
pretrained_model_name_or_path_folder.click(
500-
get_folder_path, outputs=pretrained_model_name_or_path_input
526+
get_folder_path,
527+
outputs=pretrained_model_name_or_path_input,
501528
)
502529
model_list = gr.Dropdown(
503530
label='(Optional) Model Quick Pick',
@@ -517,10 +544,10 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
517544
'same as source model',
518545
'ckpt',
519546
'diffusers',
520-
"diffusers_safetensors",
547+
'diffusers_safetensors',
521548
'safetensors',
522549
],
523-
value='same as source model'
550+
value='same as source model',
524551
)
525552
with gr.Row():
526553
v2_input = gr.Checkbox(label='v2', value=True)
@@ -607,7 +634,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
607634
)
608635
with gr.Tab('Training parameters'):
609636
with gr.Row():
610-
learning_rate_input = gr.Textbox(label='Learning rate', value=1e-6)
637+
learning_rate_input = gr.Textbox(
638+
label='Learning rate', value=1e-6
639+
)
611640
lr_scheduler_input = gr.Dropdown(
612641
label='LR Scheduler',
613642
choices=[
@@ -662,7 +691,9 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
662691
with gr.Row():
663692
seed_input = gr.Textbox(label='Seed', value=1234)
664693
max_resolution_input = gr.Textbox(
665-
label='Max resolution', value='512,512', placeholder='512,512'
694+
label='Max resolution',
695+
value='512,512',
696+
placeholder='512,512',
666697
)
667698
with gr.Row():
668699
caption_extention_input = gr.Textbox(
@@ -676,27 +707,45 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
676707
step=1,
677708
label='Stop text encoder training',
678709
)
679-
with gr.Row():
680-
full_fp16_input = gr.Checkbox(
681-
label='Full fp16 training (experimental)', value=False
682-
)
683-
no_token_padding_input = gr.Checkbox(
684-
label='No token padding', value=False
685-
)
686-
687-
gradient_checkpointing_input = gr.Checkbox(
688-
label='Gradient checkpointing', value=False
689-
)
690710
with gr.Row():
691711
enable_bucket_input = gr.Checkbox(
692712
label='Enable buckets', value=True
693713
)
694-
cache_latent_input = gr.Checkbox(label='Cache latent', value=True)
714+
cache_latent_input = gr.Checkbox(
715+
label='Cache latent', value=True
716+
)
695717
use_8bit_adam_input = gr.Checkbox(
696718
label='Use 8bit adam', value=True
697719
)
698720
xformers_input = gr.Checkbox(label='Use xformers', value=True)
699-
721+
with gr.Accordion('Advanced Configuration', open=False):
722+
with gr.Row():
723+
full_fp16_input = gr.Checkbox(
724+
label='Full fp16 training (experimental)', value=False
725+
)
726+
no_token_padding_input = gr.Checkbox(
727+
label='No token padding', value=False
728+
)
729+
730+
gradient_checkpointing_input = gr.Checkbox(
731+
label='Gradient checkpointing', value=False
732+
)
733+
734+
shuffle_caption = gr.Checkbox(
735+
label='Shuffle caption', value=False
736+
)
737+
save_state = gr.Checkbox(label='Save state', value=False)
738+
with gr.Row():
739+
resume = gr.Textbox(
740+
label='Resume',
741+
placeholder='path to "last-state" state folder to resume from',
742+
)
743+
resume_button = gr.Button('📂', elem_id='open_folder_small')
744+
resume_button.click(get_folder_path, outputs=resume)
745+
prior_loss_weight = gr.Number(
746+
label='Prior loss weight', value=1.0
747+
)
748+
700749
button_run = gr.Button('Train model')
701750

702751
with gr.Tab('Utilities'):
@@ -713,8 +762,6 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
713762
gradio_dataset_balancing_tab()
714763
gradio_convert_model_tab()
715764

716-
717-
718765
button_open_config.click(
719766
open_configuration,
720767
inputs=[
@@ -746,7 +793,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
746793
stop_text_encoder_training_input,
747794
use_8bit_adam_input,
748795
xformers_input,
749-
save_model_as_dropdown
796+
save_model_as_dropdown,
797+
shuffle_caption,
798+
save_state,
799+
resume,
800+
prior_loss_weight,
750801
],
751802
outputs=[
752803
config_file_name,
@@ -777,7 +828,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
777828
stop_text_encoder_training_input,
778829
use_8bit_adam_input,
779830
xformers_input,
780-
save_model_as_dropdown
831+
save_model_as_dropdown,
832+
shuffle_caption,
833+
save_state,
834+
resume,
835+
prior_loss_weight,
781836
],
782837
)
783838

@@ -815,7 +870,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
815870
stop_text_encoder_training_input,
816871
use_8bit_adam_input,
817872
xformers_input,
818-
save_model_as_dropdown
873+
save_model_as_dropdown,
874+
shuffle_caption,
875+
save_state,
876+
resume,
877+
prior_loss_weight,
819878
],
820879
outputs=[config_file_name],
821880
)
@@ -852,7 +911,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
852911
stop_text_encoder_training_input,
853912
use_8bit_adam_input,
854913
xformers_input,
855-
save_model_as_dropdown
914+
save_model_as_dropdown,
915+
shuffle_caption,
916+
save_state,
917+
resume,
918+
prior_loss_weight,
856919
],
857920
outputs=[config_file_name],
858921
)
@@ -887,7 +950,11 @@ def set_pretrained_model_name_or_path_input(value, v2, v_parameterization):
887950
stop_text_encoder_training_input,
888951
use_8bit_adam_input,
889952
xformers_input,
890-
save_model_as_dropdown
953+
save_model_as_dropdown,
954+
shuffle_caption,
955+
save_state,
956+
resume,
957+
prior_loss_weight,
891958
],
892959
)
893960

0 commit comments

Comments
 (0)