@@ -6,10 +6,11 @@
 # v8: supports Diffusers 0.7.2
 # v9: add bucketing option
 # v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
-# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
+# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
 # add lr scheduler options, change handling folder/file caption, support loading DiffUser model from Huggingface
 # support save_ever_n_epochs/save_state in DiffUsers model
 # fix the issue that prior_loss_weight is applyed to train images
+# v12: stop training text encoder, tqdm smoothing

 import time
 from torch.autograd.function import Function
@@ -39,33 +40,6 @@
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this model

-# model parameters for the Diffusers version of Stable Diffusion
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# V2
-V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
-V2_UNET_PARAMS_CONTEXT_DIM = 1024
-
 # checkpoint file names
 LAST_CHECKPOINT_NAME = "last.ckpt"
 LAST_STATE_NAME = "last-state"
|
@@ -693,6 +667,34 @@ def forward_xformers(self, x, context=None, mask=None):

 # region checkpoint conversion, loading and saving ###############################

+# model parameters for the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 32 # unused
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+
 # region StableDiffusion->Diffusers conversion code
 # copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)

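As an aside, a small illustrative sketch (not part of the commit) of what the UNet constants above encode: the channel multipliers scale the base model width into the per-block channel counts, essentially the arithmetic the SD-to-Diffusers conversion performs.

```python
# Illustrative only: derive the per-block channel widths from the constants above.
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]

block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
print(block_out_channels)  # [320, 640, 1280, 1280] -- the familiar SD 1.x U-Net widths
```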
@@ -1408,9 +1410,13 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
   return checkpoint


-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path):
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
   checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
   state_dict = checkpoint["state_dict"]
+  if dtype is not None:
+    for k, v in state_dict.items():
+      if type(v) is torch.Tensor:
+        state_dict[k] = v.to(dtype)

   # Convert the UNet2DConditionModel model.
   unet_config = create_unet_diffusers_config(v2)
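The new `dtype` argument simply casts every tensor in the loaded `state_dict` before the Diffusers conversion runs. A minimal standalone sketch of the same idea (the checkpoint path is illustrative):

```python
import torch

# Cast all tensors in a Stable Diffusion checkpoint to float16 before conversion,
# roughly halving the memory needed to hold the weights.
checkpoint = torch.load("model.ckpt", map_location="cpu")  # illustrative path
state_dict = checkpoint["state_dict"]
for key, value in state_dict.items():
    if isinstance(value, torch.Tensor):
        state_dict[key] = value.to(torch.float16)
```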
@@ -1854,10 +1860,15 @@ def load_dreambooth_dir(dir):
   print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
   print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

-  progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
+  progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
   global_step = 0

-  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
+  # updated in v12: clip_sample=False
+  # Diffusers' train_dreambooth.py now takes the scheduler from the model config, which sets clip_sample=False, so match that here
+  # the scheduler configs of the existing 1.4/1.5/2.0 models are all identical (apart from the class name)
+  # looking at the source more closely, this does not matter during training anyway (;'∀')
+  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+                                  num_train_timesteps=1000, clip_sample=False)

   if accelerator.is_main_process:
     accelerator.init_trackers("dreambooth")
@@ -1891,13 +1902,16 @@ def load_dreambooth_dir(dir):
         # (this is the forward diffusion process)
         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

-        # Get the text embedding for conditioning
-        if args.clip_skip is None:
-          encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-        else:
-          enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
-          encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
-          encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
+        # train the Text Encoder only up to the specified number of steps
+        train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training
+        with torch.set_grad_enabled(train_text_encoder):
+          # Get the text embedding for conditioning
+          if args.clip_skip is None:
+            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+          else:
+            enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
+            encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
+            encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)

         # Predict the noise residual
         noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
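The `torch.set_grad_enabled(...)` context is what implements the new stop option: once `global_step` reaches the threshold, the text encoder forward pass runs without building a graph, so it stops receiving gradient updates while the U-Net keeps training. A toy sketch of the same gating (module and step values are made up):

```python
import torch
import torch.nn as nn

text_encoder = nn.Linear(16, 16)   # stand-in for the CLIP text encoder
stop_text_encoder_training = 100   # e.g. the value passed on the command line
global_step = 150

train_text_encoder = stop_text_encoder_training is None or global_step < stop_text_encoder_training
with torch.set_grad_enabled(train_text_encoder):
    hidden = text_encoder(torch.randn(2, 16))

print(hidden.requires_grad)  # False once global_step >= stop_text_encoder_training
```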
@@ -1954,6 +1968,9 @@ def load_dreambooth_dir(dir):
         progress_bar.update(1)
         global_step += 1

+        if global_step == args.stop_text_encoder_training:
+          print(f"stop text encoder training at step {global_step}")
+
         current_loss = loss.detach().item()
         if args.logging_dir is not None:
           logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
@@ -2052,6 +2069,7 @@ def load_dreambooth_dir(dir):
   parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
   parser.add_argument("--no_token_padding", action="store_true",
                       help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
+  parser.add_argument("--stop_text_encoder_training", type=int, default=None, help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
   parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
   parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
   parser.add_argument("--face_crop_aug_range", type=str, default=None,