27
27
from toolkit .config import get_config
28
28
29
29
from caption import Captioner
30
+ from wandb_client import WeightsAndBiasesClient
30
31
32
+
33
+ JOB_NAME = "flux_train_replicate"
31
34
WEIGHTS_PATH = Path ("./FLUX.1-dev" )
32
35
INPUT_DIR = Path ("input_images" )
33
36
OUTPUT_DIR = Path ("output" )
37
+ JOB_DIR = OUTPUT_DIR / JOB_NAME
34
38
35
39
36
40
class CustomSDTrainer(SDTrainer):
    """SDTrainer subclass that mirrors training progress to Weights & Biases.

    The W&B client is injected by CustomJob after the job's processes are
    loaded; when it is None, all logging hooks are no-ops and training
    proceeds exactly as in the base trainer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filenames of sample images already uploaded to W&B, so each call
        # to sample() only uploads the newly generated ones.
        self.seen_samples: set[str] = set()
        # Injected by CustomJob.__init__; None disables W&B logging.
        self.wandb: WeightsAndBiasesClient | None = None

    def hook_train_loop(self, batch):
        """Run one training step and forward its loss dict to W&B, if enabled."""
        loss_dict = super().hook_train_loop(batch)
        if self.wandb:
            self.wandb.log_loss(loss_dict, self.step_num)
        return loss_dict

    def sample(self, step=None, is_first=False):
        """Generate sample images, then upload any newly created ones to W&B."""
        super().sample(step=step, is_first=is_first)
        output_dir = JOB_DIR / "samples"
        all_samples = {p.name for p in output_dir.glob("*.jpg")}
        new_samples = all_samples - self.seen_samples
        if self.wandb:
            image_paths = [output_dir / name for name in sorted(new_samples)]
            self.wandb.log_samples(image_paths, step)
        # Remember everything seen so far, even when W&B is disabled, to keep
        # the bookkeeping consistent if a client is attached later.
        self.seen_samples = all_samples

    def post_save_hook(self, save_path):
        """After a checkpoint is written, upload the newest LoRA file to W&B."""
        super().post_save_hook(save_path)
        # Final LoRA path written at the end of training.
        lora_path = JOB_DIR / f"{JOB_NAME}.safetensors"
        if not lora_path.exists():
            # Fall back to the latest intermediate checkpoint. Checkpoint names
            # sort lexicographically by step, so the last entry is the newest.
            # NOTE(review): raises IndexError if no .safetensors exist yet —
            # assumed unreachable because this hook runs after a save.
            lora_path = sorted(JOB_DIR.glob("*.safetensors"))[-1]
        if self.wandb:
            print(f"Saving weights to W&B: {lora_path.name}")
            self.wandb.save_weights(lora_path)
40
72
41
73
42
74
class CustomJob (BaseJob ):
43
- def __init__ (self , config : OrderedDict ):
75
+ def __init__ (
76
+ self , config : OrderedDict , wandb_client : WeightsAndBiasesClient | None
77
+ ):
44
78
super ().__init__ (config )
45
79
self .device = self .get_conf ("device" , "cpu" )
46
80
self .process_dict = {"custom_sd_trainer" : CustomSDTrainer }
47
81
self .load_processes (self .process_dict )
82
+ for process in self .process :
83
+ process .wandb = wandb_client
48
84
49
85
def run (self ):
50
86
super ().run ()
@@ -82,7 +118,7 @@ def train(
82
118
),
83
119
steps : int = Input (
84
120
description = "Number of training steps. Recommended range 500-4000" ,
85
- ge = 10 ,
121
+ ge = 3 ,
86
122
le = 6000 ,
87
123
default = 1000 ,
88
124
),
@@ -120,6 +156,36 @@ def train(
120
156
description = "Hugging Face token, if you'd like to upload the trained LoRA to Hugging Face." ,
121
157
default = None ,
122
158
),
159
+ wandb_api_key : Secret = Input (
160
+ description = "Weights and Biases API key, if you'd like to log training progress to W&B." ,
161
+ default = None ,
162
+ ),
163
+ wandb_project : str = Input (
164
+ description = "Weights and Biases project name. Only applicable if wandb_api_key is set." ,
165
+ default = JOB_NAME ,
166
+ ),
167
+ wandb_run : str = Input (
168
+ description = "Weights and Biases run name. Only applicable if wandb_api_key is set." ,
169
+ default = None ,
170
+ ),
171
+ wandb_entity : str = Input (
172
+ description = "Weights and Biases entity name. Only applicable if wandb_api_key is set." ,
173
+ default = None ,
174
+ ),
175
+ wandb_sample_interval : int = Input (
176
+ description = "Step interval for sampling output images that are logged to W&B. Only applicable if wandb_api_key is set." ,
177
+ default = 100 ,
178
+ ge = 1 ,
179
+ ),
180
+ wandb_sample_prompts : str = Input (
181
+ description = "Semicolon-separated list of prompts to use when logging samples to W&B. Only applicable if wandb_api_key is set." ,
182
+ default = None ,
183
+ ),
184
+ wandb_save_interval : int = Input (
185
+ description = "Step interval for saving intermediate LoRA weights to W&B. Only applicable if wandb_api_key is set." ,
186
+ default = 100 ,
187
+ ge = 1 ,
188
+ ),
123
189
skip_training_and_use_pretrained_hf_lora_url : str = Input (
124
190
description = "If you’d like to skip LoRA training altogether and instead create a Replicate model from a pre-trained LoRA that’s on HuggingFace, use this field with a HuggingFace download URL. For example, https://huggingface.co/fofr/flux-80s-cyberpunk/resolve/main/lora.safetensors." ,
125
191
default = None ,
@@ -136,14 +202,42 @@ def train(
136
202
if not input_images :
137
203
raise ValueError ("input_images must be provided" )
138
204
205
+ sample_prompts = []
206
+ if wandb_sample_prompts :
207
+ sample_prompts = [p .strip () for p in wandb_sample_prompts .split (";" )]
208
+
209
+ wandb_client = None
210
+ if wandb_api_key :
211
+ wandb_config = {
212
+ "trigger_word" : trigger_word ,
213
+ "autocaption" : autocaption ,
214
+ "autocaption_prefix" : autocaption_prefix ,
215
+ "autocaption_suffix" : autocaption_suffix ,
216
+ "steps" : steps ,
217
+ "learning_rate" : learning_rate ,
218
+ "batch_size" : batch_size ,
219
+ "resolution" : resolution ,
220
+ "lora_rank" : lora_rank ,
221
+ "caption_dropout_rate" : caption_dropout_rate ,
222
+ "optimizer" : optimizer ,
223
+ }
224
+ wandb_client = WeightsAndBiasesClient (
225
+ api_key = wandb_api_key .get_secret_value (),
226
+ config = wandb_config ,
227
+ sample_prompts = sample_prompts ,
228
+ project = wandb_project ,
229
+ entity = wandb_entity ,
230
+ name = wandb_run ,
231
+ )
232
+
139
233
download_weights ()
140
234
extract_zip (input_images , INPUT_DIR )
141
235
142
236
train_config = OrderedDict (
143
237
{
144
238
"job" : "custom_job" ,
145
239
"config" : {
146
- "name" : "flux_train_replicate" ,
240
+ "name" : JOB_NAME ,
147
241
"process" : [
148
242
{
149
243
"type" : "custom_sd_trainer" ,
@@ -157,7 +251,9 @@ def train(
157
251
},
158
252
"save" : {
159
253
"dtype" : "float16" ,
160
- "save_every" : steps + 1 ,
254
+ "save_every" : wandb_save_interval
255
+ if wandb_api_key
256
+ else steps + 1 ,
161
257
"max_step_saves_to_keep" : 1 ,
162
258
},
163
259
"datasets" : [
@@ -166,6 +262,7 @@ def train(
166
262
"caption_ext" : "txt" ,
167
263
"caption_dropout_rate" : caption_dropout_rate ,
168
264
"shuffle_tokens" : False ,
265
+ # TODO: Do we need to cache to disk? It's faster not to.
169
266
"cache_latents_to_disk" : True ,
170
267
"resolution" : [
171
268
int (res ) for res in resolution .split ("," )
@@ -193,15 +290,17 @@ def train(
193
290
},
194
291
"sample" : {
195
292
"sampler" : "flowmatch" ,
196
- "sample_every" : steps + 1 ,
293
+ "sample_every" : wandb_sample_interval
294
+ if wandb_api_key and sample_prompts
295
+ else steps + 1 ,
197
296
"width" : 1024 ,
198
297
"height" : 1024 ,
199
- "prompts" : [] ,
298
+ "prompts" : sample_prompts ,
200
299
"neg" : "" ,
201
300
"seed" : 42 ,
202
301
"walk_seed" : True ,
203
- "guidance_scale" : 4 ,
204
- "sample_steps" : 20 ,
302
+ "guidance_scale" : 3.5 ,
303
+ "sample_steps" : 28 ,
205
304
},
206
305
}
207
306
],
@@ -222,39 +321,52 @@ def train(
222
321
torch .cuda .empty_cache ()
223
322
224
323
print ("Starting train job" )
225
- job = CustomJob (get_config (train_config , name = None ))
324
+ job = CustomJob (get_config (train_config , name = None ), wandb_client )
226
325
job .run ()
326
+
327
+ if wandb_client :
328
+ wandb_client .finish ()
329
+
227
330
job .cleanup ()
228
331
229
- lora_dir = OUTPUT_DIR / "flux_train_replicate"
230
- lora_file = lora_dir / "flux_train_replicate.safetensors"
231
- lora_file .rename (lora_dir / "lora.safetensors" )
332
+ lora_file = JOB_DIR / f"{ JOB_NAME } .safetensors"
333
+ lora_file .rename (JOB_DIR / "lora.safetensors" )
334
+
335
+ samples_dir = JOB_DIR / "samples"
336
+ if samples_dir .exists ():
337
+ shutil .rmtree (samples_dir )
338
+
339
+ # Remove any intermediate lora paths
340
+ lora_paths = JOB_DIR .glob ("*.safetensors" )
341
+ for path in lora_paths :
342
+ if path .name != "lora.safetensors" :
343
+ path .unlink ()
232
344
233
345
# Optimizer is used to continue training, not needed in output
234
- optimizer_file = lora_dir / "optimizer.pt"
346
+ optimizer_file = JOB_DIR / "optimizer.pt"
235
347
if optimizer_file .exists ():
236
348
optimizer_file .unlink ()
237
349
238
350
# Copy generated captions to the output tar
239
351
# But do not upload publicly to HF
240
- captions_dir = lora_dir / "captions"
352
+ captions_dir = JOB_DIR / "captions"
241
353
captions_dir .mkdir (exist_ok = True )
242
354
for caption_file in INPUT_DIR .glob ("*.txt" ):
243
355
shutil .copy (caption_file , captions_dir )
244
356
245
- os .system (f"tar -cvf { output_path } { lora_dir } " )
357
+ os .system (f"tar -cvf { output_path } { JOB_DIR } " )
246
358
247
359
if hf_token is not None and hf_repo_id is not None :
248
360
if captions_dir .exists ():
249
361
shutil .rmtree (captions_dir )
250
362
251
363
try :
252
- handle_hf_readme (lora_dir , hf_repo_id , trigger_word )
364
+ handle_hf_readme (hf_repo_id , trigger_word )
253
365
print (f"Uploading to Hugging Face: { hf_repo_id } " )
254
366
api = HfApi ()
255
367
api .upload_folder (
256
368
repo_id = hf_repo_id ,
257
- folder_path = lora_dir ,
369
+ folder_path = JOB_DIR ,
258
370
repo_type = "model" ,
259
371
use_auth_token = hf_token .get_secret_value (),
260
372
)
@@ -264,8 +376,8 @@ def train(
264
376
return TrainingOutput (weights = Path (output_path ))
265
377
266
378
267
- def handle_hf_readme (lora_dir : Path , hf_repo_id : str , trigger_word : Optional [str ]):
268
- readme_path = lora_dir / "README.md"
379
+ def handle_hf_readme (hf_repo_id : str , trigger_word : Optional [str ]):
380
+ readme_path = JOB_DIR / "README.md"
269
381
license_path = Path ("lora-license.md" )
270
382
shutil .copy (license_path , readme_path )
271
383
0 commit comments