
Commit 188edd3

v12 release
1 parent 5a68cd0 commit 188edd3

File tree

2 files changed: +61, -39 lines


README.md

Lines changed: 5 additions & 1 deletion
@@ -390,4 +390,8 @@ options:
 - Fixed a bug where prior_loss_weight was applied to learning images. Sorry for the inconvenience.
 - Compatible with Stable Diffusion v2.0. Add the `--v2` option. If you are using `768-v-ema.ckpt` or `stable-diffusion-2` instead of `stable-diffusion-v2-base`, add `--v_parameterization` as well. Learn more about other options.
 - Added options related to the learning rate scheduler.
-- You can download and use DiffUsers models directly from Hugging Face. In addition, DiffUsers models can be saved during training.
+- You can download and use DiffUsers models directly from Hugging Face. In addition, DiffUsers models can be saved during training.
+* 11/29 (v12) update:
+  - stop training the text encoder at a specified step (`--stop_text_encoder_training=<step #>`)
+  - tqdm smoothing
+  - updated the fine tuning script to support SD2.0 768/v
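The headline v12 change is the step gate on the text encoder. A minimal sketch of the intended behavior (`stop_step = 500` is an illustrative value; the real check uses `args.stop_text_encoder_training` and `global_step` in the train_db_fixed.py diff below):

```python
# Illustrative sketch only; mirrors the v12 gate in train_db_fixed.py.
stop_step = 500  # hypothetical value for --stop_text_encoder_training

for global_step in range(1000):
    # the text encoder receives gradient updates only below the cap;
    # afterwards only the U-Net continues to train
    train_text_encoder = stop_step is None or global_step < stop_step
```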

train_db_fixed.py

Lines changed: 56 additions & 38 deletions
@@ -6,10 +6,11 @@
 # v8: supports Diffusers 0.7.2
 # v9: add bucketing option
 # v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
-# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
+# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
 #      add lr scheduler options, change handling of folder/file captions, support loading DiffUsers model from Hugging Face
 #      support save_every_n_epochs/save_state in DiffUsers model
 #      fix the issue that prior_loss_weight is applied to train images
+# v12: stop training the text encoder at a specified step, tqdm smoothing
 
 import time
 from torch.autograd.function import Function
@@ -39,33 +40,6 @@
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # only the tokenizer is used from here
 
-# model parameters for the Diffusers version of Stable Diffusion
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32  # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# V2
-V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
-V2_UNET_PARAMS_CONTEXT_DIM = 1024
-
 # checkpoint file names
 LAST_CHECKPOINT_NAME = "last.ckpt"
 LAST_STATE_NAME = "last-state"
@@ -693,6 +667,34 @@ def forward_xformers(self, x, context=None, mask=None):
 
 # region checkpoint conversion, loading and writing ###############################
 
+# model parameters for the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 32  # unused
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+
 # region StableDiffusion->Diffusers conversion code
 # copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
 
@@ -1408,9 +1410,13 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
   return checkpoint
 
 
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path):
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
   state_dict = checkpoint["state_dict"]
+  if dtype is not None:
+    for k, v in state_dict.items():
+      if type(v) is torch.Tensor:
+        state_dict[k] = v.to(dtype)
 
   # Convert the UNet2DConditionModel model.
   unet_config = create_unet_diffusers_config(v2)
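The new `dtype` parameter casts every tensor in the checkpoint's state dict before the Diffusers conversion runs, so weights can be loaded in half precision from the start. A hedged usage sketch (the path is hypothetical, and the three-value unpacking assumes the function returns the text encoder, VAE, and U-Net in that order):

```python
import torch

# hypothetical checkpoint path; v2=True selects the SD2.0 conversion path,
# dtype=torch.float16 casts all state-dict tensors before conversion
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    True, "models/768-v-ema.ckpt", dtype=torch.float16)
```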
@@ -1854,10 +1860,15 @@ def load_dreambooth_dir(dir):
   print(f"  total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
   print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
-  progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
+  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
   global_step = 0
 
-  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+  # updated in v12: clip_sample=False
+  # Diffusers' train_dreambooth.py was changed to take the scheduler from the model config, which makes clip_sample=False, so match that
+  # the scheduler configs of the existing 1.4/1.5/2.0 models are all the same (apart from the class name)
+  # on a closer look at the source, this actually doesn't matter during training (;'∀')
+  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                                  num_train_timesteps=1000, clip_sample=False)
 
   if accelerator.is_main_process:
     accelerator.init_trackers("dreambooth")
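Two behavior notes are bundled in this hunk. `smoothing=0` makes tqdm estimate speed from the average over all completed iterations rather than its default exponential moving average of recent ones, which steadies the ETA when per-step time fluctuates. A standalone illustration:

```python
import time
from tqdm import tqdm

# smoothing=0: rate/ETA use the average over the whole run,
# instead of tqdm's default EMA (smoothing=0.3) of recent steps
for _ in tqdm(range(100), smoothing=0, desc="steps"):
    time.sleep(0.01)
```

As the comment above the scheduler says, `clip_sample` only matters when sampling via `scheduler.step()`; the `add_noise()` call used during training ignores it, so the change is for config consistency only.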
@@ -1891,13 +1902,16 @@ def load_dreambooth_dir(dir):
         # (this is the forward diffusion process)
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-        # Get the text embedding for conditioning
-        if args.clip_skip is None:
-          encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-        else:
-          enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
-          encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
-          encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
+        # train the Text Encoder only up to the specified number of steps
+        train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training
+        with torch.set_grad_enabled(train_text_encoder):
+          # Get the text embedding for conditioning
+          if args.clip_skip is None:
+            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+          else:
+            enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
+            encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
+            encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
 
         # Predict the noise residual
         noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
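The gate works because `torch.set_grad_enabled` doubles as a context manager: with a False argument, no autograd graph is recorded for the text encoder's forward pass, so backward produces no gradients for its parameters and the optimizer step leaves them untouched. A minimal self-contained illustration:

```python
import torch

lin = torch.nn.Linear(4, 4)  # stand-in for the text encoder
x = torch.randn(1, 4)

with torch.set_grad_enabled(False):
    y = lin(x)
print(y.requires_grad)  # False: no graph, so lin's weights get no gradients

with torch.set_grad_enabled(True):
    y = lin(x)
print(y.requires_grad)  # True: gradients flow back to lin again
```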
@@ -1954,6 +1968,9 @@ def load_dreambooth_dir(dir):
         progress_bar.update(1)
         global_step += 1
 
+        if global_step == args.stop_text_encoder_training:
+          print(f"stop text encoder training at step {global_step}")
+
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
           logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
@@ -2052,6 +2069,7 @@ def load_dreambooth_dir(dir):
   parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
   parser.add_argument("--no_token_padding", action="store_true",
                       help="disable token padding (same as Diffusers' DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
+  parser.add_argument("--stop_text_encoder_training", type=int, default=None, help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
   parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
   parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
   parser.add_argument("--face_crop_aug_range", type=str, default=None,
