Skip to content

Commit e01506a

Browse files
Add weights&biases logging
1 parent e351d3a commit e01506a

File tree

3 files changed

+177
-22
lines changed

3 files changed

+177
-22
lines changed

cog.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ build:
5050
- "numpy==1.26.0"
5151
- "shortuuid==1.0.11"
5252
- "tokenizers==0.19"
53-
- "wandb==0.15.12"
53+
- "wandb==0.17.8"
5454
- "wavedrom==2.0.3.post3"
5555
- "Pygments==2.16.1"
5656
run:

train.py

Lines changed: 133 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,60 @@
2727
from toolkit.config import get_config
2828

2929
from caption import Captioner
30+
from wandb_client import WeightsAndBiasesClient
3031

32+
33+
JOB_NAME = "flux_train_replicate"
3134
WEIGHTS_PATH = Path("./FLUX.1-dev")
3235
INPUT_DIR = Path("input_images")
3336
OUTPUT_DIR = Path("output")
37+
JOB_DIR = OUTPUT_DIR / JOB_NAME
3438

3539

3640
class CustomSDTrainer(SDTrainer):
    """SDTrainer subclass that mirrors training progress to Weights & Biases.

    The ``wandb`` attribute starts out as ``None``; while it is ``None`` every
    hook below only performs the parent behaviour and sends nothing to W&B.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # File names of the sample images already observed, so each sampling
        # round only uploads images written since the previous round.
        self.seen_samples = set()
        self.wandb: WeightsAndBiasesClient | None = None

    def hook_train_loop(self, batch):
        """Run the parent training step and forward its loss dict to W&B.

        Returns the loss dict produced by the parent unchanged.
        """
        loss_dict = super().hook_train_loop(batch)
        if self.wandb:
            self.wandb.log_loss(loss_dict, self.step_num)
        return loss_dict

    def sample(self, step=None, is_first=False):
        """Generate samples via the parent, then log the newly written images.

        New images are detected by diffing the ``*.jpg`` file names in the
        job's ``samples`` directory against the names seen on earlier calls.
        """
        super().sample(step=step, is_first=is_first)
        output_dir = JOB_DIR / "samples"
        all_samples = {p.name for p in output_dir.glob("*.jpg")}
        new_samples = all_samples - self.seen_samples
        # Skip the client call entirely when nothing new was produced.
        if self.wandb and new_samples:
            # Sorted by file name so the order handed to the client is
            # deterministic.
            image_paths = [output_dir / name for name in sorted(new_samples)]
            self.wandb.log_samples(image_paths, step)
        self.seen_samples = all_samples

    def post_save_hook(self, save_path):
        """Run the parent hook, then upload the latest LoRA weights to W&B.

        Prefers the final ``<JOB_NAME>.safetensors`` file; when it does not
        exist yet, falls back to the last ``*.safetensors`` file in the job
        directory by sorted path order.
        """
        super().post_save_hook(save_path)
        if not self.wandb:
            # Nothing to upload: avoid touching the filesystem at all.
            return

        # final lora path
        lora_path = JOB_DIR / f"{JOB_NAME}.safetensors"
        if not lora_path.exists():
            # intermediate saved weights
            lora_path = max(JOB_DIR.glob("*.safetensors"), default=None)
            if lora_path is None:
                # Previously this case raised IndexError on an empty list.
                print("No LoRA weights found to save to W&B")
                return

        print(f"Saving weights to W&B: {lora_path.name}")
        self.wandb.save_weights(lora_path)
4072

4173

4274
class CustomJob(BaseJob):
43-
def __init__(self, config: OrderedDict):
75+
def __init__(
76+
self, config: OrderedDict, wandb_client: WeightsAndBiasesClient | None
77+
):
4478
super().__init__(config)
4579
self.device = self.get_conf("device", "cpu")
4680
self.process_dict = {"custom_sd_trainer": CustomSDTrainer}
4781
self.load_processes(self.process_dict)
82+
for process in self.process:
83+
process.wandb = wandb_client
4884

4985
def run(self):
5086
super().run()
@@ -82,7 +118,7 @@ def train(
82118
),
83119
steps: int = Input(
84120
description="Number of training steps. Recommended range 500-4000",
85-
ge=10,
121+
ge=3,
86122
le=6000,
87123
default=1000,
88124
),
@@ -120,6 +156,36 @@ def train(
120156
description="Hugging Face token, if you'd like to upload the trained LoRA to Hugging Face.",
121157
default=None,
122158
),
159+
wandb_api_key: Secret = Input(
160+
description="Weights and Biases API key, if you'd like to log training progress to W&B.",
161+
default=None,
162+
),
163+
wandb_project: str = Input(
164+
description="Weights and Biases project name. Only applicable if wandb_api_key is set.",
165+
default=JOB_NAME,
166+
),
167+
wandb_run: str = Input(
168+
description="Weights and Biases run name. Only applicable if wandb_api_key is set.",
169+
default=None,
170+
),
171+
wandb_entity: str = Input(
172+
description="Weights and Biases entity name. Only applicable if wandb_api_key is set.",
173+
default=None,
174+
),
175+
wandb_sample_interval: int = Input(
176+
description="Step interval for sampling output images that are logged to W&B. Only applicable if wandb_api_key is set.",
177+
default=100,
178+
ge=1,
179+
),
180+
wandb_sample_prompts: str = Input(
181+
description="Semicolon-separated list of prompts to use when logging samples to W&B. Only applicable if wandb_api_key is set.",
182+
default=None,
183+
),
184+
wandb_save_interval: int = Input(
185+
description="Step interval for saving intermediate LoRA weights to W&B. Only applicable if wandb_api_key is set.",
186+
default=100,
187+
ge=1,
188+
),
123189
skip_training_and_use_pretrained_hf_lora_url: str = Input(
124190
description="If you’d like to skip LoRA training altogether and instead create a Replicate model from a pre-trained LoRA that’s on HuggingFace, use this field with a HuggingFace download URL. For example, https://huggingface.co/fofr/flux-80s-cyberpunk/resolve/main/lora.safetensors.",
125191
default=None,
@@ -136,14 +202,42 @@ def train(
136202
if not input_images:
137203
raise ValueError("input_images must be provided")
138204

205+
sample_prompts = []
206+
if wandb_sample_prompts:
207+
sample_prompts = [p.strip() for p in wandb_sample_prompts.split(";")]
208+
209+
wandb_client = None
210+
if wandb_api_key:
211+
wandb_config = {
212+
"trigger_word": trigger_word,
213+
"autocaption": autocaption,
214+
"autocaption_prefix": autocaption_prefix,
215+
"autocaption_suffix": autocaption_suffix,
216+
"steps": steps,
217+
"learning_rate": learning_rate,
218+
"batch_size": batch_size,
219+
"resolution": resolution,
220+
"lora_rank": lora_rank,
221+
"caption_dropout_rate": caption_dropout_rate,
222+
"optimizer": optimizer,
223+
}
224+
wandb_client = WeightsAndBiasesClient(
225+
api_key=wandb_api_key.get_secret_value(),
226+
config=wandb_config,
227+
sample_prompts=sample_prompts,
228+
project=wandb_project,
229+
entity=wandb_entity,
230+
name=wandb_run,
231+
)
232+
139233
download_weights()
140234
extract_zip(input_images, INPUT_DIR)
141235

142236
train_config = OrderedDict(
143237
{
144238
"job": "custom_job",
145239
"config": {
146-
"name": "flux_train_replicate",
240+
"name": JOB_NAME,
147241
"process": [
148242
{
149243
"type": "custom_sd_trainer",
@@ -157,7 +251,9 @@ def train(
157251
},
158252
"save": {
159253
"dtype": "float16",
160-
"save_every": steps + 1,
254+
"save_every": wandb_save_interval
255+
if wandb_api_key
256+
else steps + 1,
161257
"max_step_saves_to_keep": 1,
162258
},
163259
"datasets": [
@@ -166,6 +262,7 @@ def train(
166262
"caption_ext": "txt",
167263
"caption_dropout_rate": caption_dropout_rate,
168264
"shuffle_tokens": False,
265+
# TODO: Do we need to cache to disk? It's faster not to.
169266
"cache_latents_to_disk": True,
170267
"resolution": [
171268
int(res) for res in resolution.split(",")
@@ -193,15 +290,17 @@ def train(
193290
},
194291
"sample": {
195292
"sampler": "flowmatch",
196-
"sample_every": steps + 1,
293+
"sample_every": wandb_sample_interval
294+
if wandb_api_key and sample_prompts
295+
else steps + 1,
197296
"width": 1024,
198297
"height": 1024,
199-
"prompts": [],
298+
"prompts": sample_prompts,
200299
"neg": "",
201300
"seed": 42,
202301
"walk_seed": True,
203-
"guidance_scale": 4,
204-
"sample_steps": 20,
302+
"guidance_scale": 3.5,
303+
"sample_steps": 28,
205304
},
206305
}
207306
],
@@ -222,39 +321,52 @@ def train(
222321
torch.cuda.empty_cache()
223322

224323
print("Starting train job")
225-
job = CustomJob(get_config(train_config, name=None))
324+
job = CustomJob(get_config(train_config, name=None), wandb_client)
226325
job.run()
326+
327+
if wandb_client:
328+
wandb_client.finish()
329+
227330
job.cleanup()
228331

229-
lora_dir = OUTPUT_DIR / "flux_train_replicate"
230-
lora_file = lora_dir / "flux_train_replicate.safetensors"
231-
lora_file.rename(lora_dir / "lora.safetensors")
332+
lora_file = JOB_DIR / f"{JOB_NAME}.safetensors"
333+
lora_file.rename(JOB_DIR / "lora.safetensors")
334+
335+
samples_dir = JOB_DIR / "samples"
336+
if samples_dir.exists():
337+
shutil.rmtree(samples_dir)
338+
339+
# Remove any intermediate lora paths
340+
lora_paths = JOB_DIR.glob("*.safetensors")
341+
for path in lora_paths:
342+
if path.name != "lora.safetensors":
343+
path.unlink()
232344

233345
# Optimizer is used to continue training, not needed in output
234-
optimizer_file = lora_dir / "optimizer.pt"
346+
optimizer_file = JOB_DIR / "optimizer.pt"
235347
if optimizer_file.exists():
236348
optimizer_file.unlink()
237349

238350
# Copy generated captions to the output tar
239351
# But do not upload publicly to HF
240-
captions_dir = lora_dir / "captions"
352+
captions_dir = JOB_DIR / "captions"
241353
captions_dir.mkdir(exist_ok=True)
242354
for caption_file in INPUT_DIR.glob("*.txt"):
243355
shutil.copy(caption_file, captions_dir)
244356

245-
os.system(f"tar -cvf {output_path} {lora_dir}")
357+
os.system(f"tar -cvf {output_path} {JOB_DIR}")
246358

247359
if hf_token is not None and hf_repo_id is not None:
248360
if captions_dir.exists():
249361
shutil.rmtree(captions_dir)
250362

251363
try:
252-
handle_hf_readme(lora_dir, hf_repo_id, trigger_word)
364+
handle_hf_readme(hf_repo_id, trigger_word)
253365
print(f"Uploading to Hugging Face: {hf_repo_id}")
254366
api = HfApi()
255367
api.upload_folder(
256368
repo_id=hf_repo_id,
257-
folder_path=lora_dir,
369+
folder_path=JOB_DIR,
258370
repo_type="model",
259371
use_auth_token=hf_token.get_secret_value(),
260372
)
@@ -264,8 +376,8 @@ def train(
264376
return TrainingOutput(weights=Path(output_path))
265377

266378

267-
def handle_hf_readme(lora_dir: Path, hf_repo_id: str, trigger_word: Optional[str]):
268-
readme_path = lora_dir / "README.md"
379+
def handle_hf_readme(hf_repo_id: str, trigger_word: Optional[str]):
380+
readme_path = JOB_DIR / "README.md"
269381
license_path = Path("lora-license.md")
270382
shutil.copy(license_path, readme_path)
271383

wandb_client.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from pathlib import Path
2+
from typing import Any, Sequence
3+
import wandb
4+
from wandb.sdk.wandb_settings import Settings
5+
6+
7+
class WeightsAndBiasesClient:
    """Small wrapper around the ``wandb`` module for a single training run.

    Constructing the client logs in with the given API key and starts a run;
    call ``finish`` once training is done.
    """

    def __init__(
        self,
        api_key: str,
        project: str,
        config: dict,
        sample_prompts: list[str],
        entity: str | None,
        name: str | None,
    ):
        """Log in to W&B and start a run.

        Args:
            api_key: W&B API key passed to ``wandb.login``.
            project: Project name passed to ``wandb.init``.
            config: Run configuration passed to ``wandb.init``.
            sample_prompts: Prompts used to label logged sample images.
            entity: Entity passed to ``wandb.init`` (may be ``None``).
            name: Run name passed to ``wandb.init`` (may be ``None``).
        """
        self.api_key = api_key
        self.sample_prompts = sample_prompts
        wandb.login(key=self.api_key, verify=True)
        self.run = wandb.init(
            project=project,
            entity=entity,
            name=name,
            config=config,
            save_code=False,
            settings=Settings(_disable_machine_info=True),
        )

    def log_loss(self, loss_dict: dict[str, Any], step: int | None):
        """Log the loss values in ``loss_dict`` at ``step``."""
        wandb.log(data=loss_dict, step=step)

    def log_samples(self, image_paths: Sequence[Path], step: int | None):
        """Log sample images at ``step``, keyed by ``samples/<prompt>``.

        Prompts and images are paired positionally; surplus entries on either
        side are dropped. Nothing is logged when no pair can be formed.
        """
        data = {
            f"samples/{prompt}": wandb.Image(str(path))
            for prompt, path in zip(self.sample_prompts, image_paths)
        }
        if not data:
            # Avoid issuing an empty log call.
            return
        if len(image_paths) != len(self.sample_prompts):
            # zip() truncates silently; make the mismatch visible.
            print(
                f"W&B: got {len(image_paths)} sample images for "
                f"{len(self.sample_prompts)} prompts; logging {len(data)}"
            )
        wandb.log(data=data, step=step)

    def save_weights(self, lora_path: Path):
        """Upload the weights file at ``lora_path`` to the current run."""
        # Pass a plain string, matching how paths are handed to wandb.Image;
        # presumably also safer on wandb releases that only accept str globs
        # (TODO confirm against the pinned version).
        wandb.save(str(lora_path))

    def finish(self):
        """Mark the current W&B run as finished."""
        wandb.finish()

0 commit comments

Comments
 (0)