
Commit 9fba11a

default mod

1 parent 3227f71 commit 9fba11a

File tree: example_datasets/zeke.zip, train.py, trainer_pti.py

3 files changed: +56 -48 lines

example_datasets/zeke.zip (840 KB)

Binary file not shown.

train.py (11 additions, 20 deletions)
@@ -40,7 +40,7 @@ def train(
     ),
     num_train_epochs: int = Input(
         description="Number of epochs to loop through your training dataset",
-        default=400,
+        default=4000,
     ),
     max_train_steps: int = Input(
         description="Number of individual training steps. Takes precedence over num_train_epochs",
@@ -52,34 +52,23 @@ def train(
     # ), # todo.
     unet_learning_rate: float = Input(
         description="Learning rate for the U-Net. We recommend this value to be somewhere between `1e-6` to `1e-5`.",
-        default=3e-6,
+        default=1e-6,
     ),
     ti_learning_rate_multiplier: float = Input(
         description="Scaling of learning rate for training textual inversion embeddings. Don't alter unless you know what you're doing.",
-        default=100,
+        default=1000,
     ),
     lr_scheduler: str = Input(
         description="Learning rate scheduler to use for training",
         default="constant",
         choices=[
             "constant",
             "linear",
-            "cosine",
-            "cosine_with_restarts",
-            "polynomial",
-            "constant_with_warmup",
         ],
     ),
     lr_warmup_steps: int = Input(
         description="Number of warmup steps for lr schedulers with warmups.",
-        default=500,
-    ),
-    lr_num_cycles: int = Input(
-        description="Number of hard restarts used with `cosine_with_restarts` learning rate scheduler",
-        default=1,
-    ),
-    lr_power: float = Input(
-        description="Power for polynomial learning rate scheduler", default=1.0
+        default=100,
     ),
     token_string: str = Input(
         description="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well",
@@ -103,16 +92,20 @@ def train(
     ),
     use_face_detection_instead: bool = Input(
         description="If you want to use face detection instead of CLIPSeg for masking. For face applications, we recommend using this option.",
-        default=False,
+        default=True,
     ),
     clipseg_temperature: float = Input(
         description="How blurry you want the CLIPSeg mask to be. We recommend this value be something between `0.5` to `1.0`. If you want to have more sharp mask (but thus more errorful), you can decrease this value.",
         default=1.0,
     ),
     verbose: bool = Input(description="verbose output", default=True),
+    checkpointing_steps: int = Input(
+        description="Number of steps between saving checkpoints. Set to very very high number to disable checkpointing, because you don't need one.",
+        default=200,
+    ),
 ) -> TrainingOutput:
     # Hard-code token_map for now. Make it configurable once we support multiple concepts or user-uploaded caption csv.
-    token_map = token_string + ":2"
+    token_map = token_string + ":3"
 
     # Process 'token_to_train' and 'input_data_tar_or_zip'
     inserting_list_tokens = token_map.split(",")
@@ -161,15 +154,13 @@ def train(
         ti_learning_rate_multiplier=ti_learning_rate_multiplier,
         lr_scheduler=lr_scheduler,
         lr_warmup_steps=lr_warmup_steps,
-        lr_num_cycles=lr_num_cycles,
-        lr_power=lr_power,
         token_dict=token_dict,
         inserting_list_tokens=all_token_lists,
         verbose=verbose,
         crops_coords_top_left_h=0,
         crops_coords_top_left_w=0,
         do_cache=True,
-        checkpointing_steps=500000,
+        checkpointing_steps=checkpointing_steps,
         scale_lr=False,
         dataloader_num_workers=0,
         max_grad_norm=1.0,
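
Note on the defaults above: besides the larger epoch count, lower U-Net learning rate, higher textual-inversion multiplier, and the new user-facing checkpointing_steps input, the one behavioural change is token_map, which now requests three textual-inversion placeholder tokens per concept instead of two. The sketch below illustrates the assumed "name:count" expansion; expand_token_map is a hypothetical helper used only for illustration, and the real parsing further down in train.py may differ in detail.

# Minimal sketch, assuming "TOK:3" means "train 3 placeholder embeddings for TOK".
def expand_token_map(token_map: str) -> dict[str, list[str]]:
    token_dict = {}
    next_index = 0
    for pair in token_map.split(","):
        name, count = pair.split(":")
        tokens = [f"<s{next_index + i}>" for i in range(int(count))]
        next_index += int(count)
        token_dict[name] = tokens
    return token_dict

print(expand_token_map("TOK:3"))  # {'TOK': ['<s0>', '<s1>', '<s2>']}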

trainer_pti.py (45 additions, 28 deletions)
@@ -5,6 +5,7 @@
 import os
 import shutil
 from typing import Optional
+import fnmatch
 
 import numpy as np
 import torch
@@ -13,8 +14,7 @@
 from safetensors.torch import save_file
 from tqdm.auto import tqdm
 
-from dataset_and_utils import (PreprocessedDataset, TokenEmbeddingsHandler,
-                               load_models)
+from dataset_and_utils import PreprocessedDataset, TokenEmbeddingsHandler, load_models
 
 
 def main(
@@ -49,6 +49,8 @@ def main(
     token_dict={"TOKEN": "<s0>"},
     inserting_list_tokens=["<s0>"],
     verbose: bool = True,
+    is_lora=False,
+    lora_rank=32,
 ) -> None:
     if allow_tf32:
         torch.backends.cuda.matmul.allow_tf32 = True
@@ -91,15 +93,6 @@ def main(
 
     unet_param_to_optimize = []
     # fine tune only attn weights
-    unet_param_to_optimize_names = []
-    for name, param in unet.named_parameters():
-        if "weight" in name and "norm" not in name:
-            param.requires_grad_(True)
-            unet_param_to_optimize.append(param)
-            unet_param_to_optimize_names.append(name)
-            print(name)
-        else:
-            param.requires_grad_(False)
 
     text_encoder_parameters = []
     for text_encoder in text_encoders:
@@ -111,23 +104,47 @@ def main(
             else:
                 param.requires_grad = False
 
-    # Optimizer creation
-    params_to_optimize = [
-        {
-            "params": unet_param_to_optimize,
-            "lr": unet_learning_rate,
-        },
-        {
-            "params": text_encoder_parameters,
-            "lr": ti_learning_rate_multiplier * unet_learning_rate,
-            "weight_decay": 1e-1,
-        },
-    ]
-
-    optimizer = torch.optim.AdamW(
-        params_to_optimize,
-        weight_decay=1e-4,
-    )
+    if not is_lora:
+        WHITELIST_PATTERNS = [
+            "*.attn*.weight",
+            "*ff*.weight",
+        ]  # TODO : make this a parameter
+        BLACKLIST_PATTERNS = ["*.norm*.weight"]
+
+        unet_param_to_optimize_names = []
+        for name, param in unet.named_parameters():
+            if any(
+                fnmatch.fnmatch(name, pattern) for pattern in WHITELIST_PATTERNS
+            ) and not any(
+                fnmatch.fnmatch(name, pattern) for pattern in BLACKLIST_PATTERNS
+            ):
+                param.requires_grad_(True)
+                unet_param_to_optimize_names.append(name)
+                print(f"Training: {name}")
+            else:
+                param.requires_grad_(False)
+
+        # Optimizer creation
+        params_to_optimize = [
+            {
+                "params": unet_param_to_optimize,
+                "lr": unet_learning_rate,
+            },
+            {
+                "params": text_encoder_parameters,
+                "lr": ti_learning_rate_multiplier * unet_learning_rate,
+                "weight_decay": 1e-1,
+            },
+        ]
+
+        optimizer = torch.optim.AdamW(
+            params_to_optimize,
+            weight_decay=1e-4,
+        )
+
+    else:
+        # Do lora-training instead.
+        pass
 
     print(f"# PTI : Loading dataset, do_cache {do_cache}")
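
The rewritten selection logic above replaces the old substring checks ("weight" in name and "norm" not in name) with fnmatch glob patterns, so only attention and feed-forward weights are unfrozen while norm weights stay excluded; the matched parameters then sit in the first AdamW group at unet_learning_rate, with the text-encoder embeddings in a second group at ti_learning_rate_multiplier times that rate. Below is a self-contained sketch of how those globs behave, using made-up parameter names in the style of an SDXL UNet state dict (the names are illustrative, not copied from a real checkpoint).

import fnmatch

WHITELIST_PATTERNS = ["*.attn*.weight", "*ff*.weight"]
BLACKLIST_PATTERNS = ["*.norm*.weight"]

# Illustrative names only; not taken from an actual model.
example_names = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight",
    "conv_in.weight",
]

def is_trainable(name: str) -> bool:
    # Trainable if the name matches any whitelist glob and no blacklist glob.
    whitelisted = any(fnmatch.fnmatch(name, p) for p in WHITELIST_PATTERNS)
    blacklisted = any(fnmatch.fnmatch(name, p) for p in BLACKLIST_PATTERNS)
    return whitelisted and not blacklisted

for name in example_names:
    status = "train" if is_trainable(name) else "freeze"
    print(f"{status}: {name}")

As committed, the is_lora branch is still a stub (pass), so calling main with is_lora=True falls through this block without creating an optimizer.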
