Skip to content

Commit

Permalink
Add Weights & Biases logging
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Sep 4, 2024
1 parent b12ec33 commit e69b0f9
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ build:
- "numpy==1.26.0"
- "shortuuid==1.0.11"
- "tokenizers==0.19"
- "wandb==0.15.12"
- "wandb==0.17.8"
- "wavedrom==2.0.3.post3"
- "Pygments==2.16.1"
run:
Expand Down
154 changes: 133 additions & 21 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,60 @@
from toolkit.config import get_config

from caption import Captioner
from wandb_client import WeightsAndBiasesClient


# Name of the training job: used as the trainer config "name", as the default
# W&B project name, and as the stem of the final "<JOB_NAME>.safetensors" file.
JOB_NAME = "flux_train_replicate"
# Local directory of the base model weights
# (presumably populated by download_weights() -- confirm against that helper).
WEIGHTS_PATH = Path("./FLUX.1-dev")
# Directory the uploaded training zip is extracted into; caption .txt files
# are later collected from here.
INPUT_DIR = Path("input_images")
# Root directory for training output.
OUTPUT_DIR = Path("output")
# Per-job output directory: holds samples/, *.safetensors checkpoints,
# optimizer.pt and captions/, and is what gets tarred and uploaded.
JOB_DIR = OUTPUT_DIR / JOB_NAME


class CustomSDTrainer(SDTrainer):
    """SDTrainer subclass that forwards training progress to Weights & Biases.

    The ``wandb`` attribute starts out as ``None``; whoever owns the trainer
    may assign a ``WeightsAndBiasesClient`` to it afterwards. While it is
    ``None`` every hook only runs the parent implementation (plus local
    sample bookkeeping).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # File names of sample images that have already been seen/reported.
        self.seen_samples = set()
        self.wandb: WeightsAndBiasesClient | None = None

    def hook_train_loop(self, batch):
        """Run the parent training step and log its loss dict to W&B.

        Returns the loss dict produced by the parent unchanged.
        """
        loss_dict = super().hook_train_loop(batch)
        if self.wandb:
            self.wandb.log_loss(loss_dict, self.step_num)
        return loss_dict

    def sample(self, step=None, is_first=False):
        """Run the parent sampler, then log newly written sample images.

        New images are detected by diffing the ``*.jpg`` file names in
        ``JOB_DIR / "samples"`` against the names seen on earlier calls.
        """
        super().sample(step=step, is_first=is_first)
        output_dir = JOB_DIR / "samples"
        all_samples = {p.name for p in output_dir.glob("*.jpg")}
        new_samples = all_samples - self.seen_samples
        if self.wandb:
            # NOTE(review): paths are ordered by file name and the client
            # pairs them with prompts positionally -- this assumes file-name
            # order matches prompt order; confirm against the sampler's
            # naming scheme.
            image_paths = [output_dir / p for p in sorted(new_samples)]
            self.wandb.log_samples(image_paths, step)
        self.seen_samples = all_samples

    def post_save_hook(self, save_path):
        """Run the parent hook, then upload the latest LoRA weights to W&B.

        Prefers the final ``<JOB_NAME>.safetensors`` file; if it does not
        exist yet, falls back to the last ``*.safetensors`` file in
        ``JOB_DIR`` by sorted path. Does nothing extra when no W&B client is
        attached or when no weights file is present.
        """
        super().post_save_hook(save_path)
        if not self.wandb:
            # Nothing to upload to; skip the filesystem lookup entirely.
            return

        # final lora path
        lora_path = JOB_DIR / f"{JOB_NAME}.safetensors"
        if not lora_path.exists():
            # intermediate saved weights
            checkpoints = sorted(JOB_DIR.glob("*.safetensors"))
            if not checkpoints:
                # Avoid IndexError when no checkpoint has been written.
                return
            lora_path = checkpoints[-1]

        print(f"Saving weights to W&B: {lora_path.name}")
        self.wandb.save_weights(lora_path)


class CustomJob(BaseJob):
def __init__(
    self,
    config: OrderedDict,
    wandb_client: WeightsAndBiasesClient | None = None,
):
    """Set up the job and hand the W&B client to every loaded process.

    Args:
        config: Job configuration passed through to the parent constructor.
        wandb_client: Optional W&B client; assigned to the ``wandb``
            attribute of each loaded process. Defaults to ``None`` so the
            earlier single-argument call ``CustomJob(config)`` keeps working.
    """
    super().__init__(config)
    self.device = self.get_conf("device", "cpu")
    # Maps the process "type" used in the config to its trainer class.
    self.process_dict = {"custom_sd_trainer": CustomSDTrainer}
    self.load_processes(self.process_dict)
    for process in self.process:
        process.wandb = wandb_client

def run(self):
super().run()
Expand Down Expand Up @@ -82,7 +118,7 @@ def train(
),
steps: int = Input(
description="Number of training steps. Recommended range 500-4000",
ge=10,
ge=3,
le=6000,
default=1000,
),
Expand Down Expand Up @@ -120,6 +156,36 @@ def train(
description="Hugging Face token, if you'd like to upload the trained LoRA to Hugging Face.",
default=None,
),
wandb_api_key: Secret = Input(
description="Weights and Biases API key, if you'd like to log training progress to W&B.",
default=None,
),
wandb_project: str = Input(
description="Weights and Biases project name. Only applicable if wandb_api_key is set.",
default=JOB_NAME,
),
wandb_run: str = Input(
description="Weights and Biases run name. Only applicable if wandb_api_key is set.",
default=None,
),
wandb_entity: str = Input(
description="Weights and Biases entity name. Only applicable if wandb_api_key is set.",
default=None,
),
wandb_sample_interval: int = Input(
description="Step interval for sampling output images that are logged to W&B. Only applicable if wandb_api_key is set.",
default=100,
ge=1,
),
wandb_sample_prompts: str = Input(
description="Semicolon-separated list of prompts to use when logging samples to W&B. Only applicable if wandb_api_key is set.",
default=None,
),
wandb_save_interval: int = Input(
description="Step interval for saving intermediate LoRA weights to W&B. Only applicable if wandb_api_key is set.",
default=100,
ge=1,
),
skip_training_and_use_pretrained_hf_lora_url: str = Input(
description="If you’d like to skip LoRA training altogether and instead create a Replicate model from a pre-trained LoRA that’s on HuggingFace, use this field with a HuggingFace download URL. For example, https://huggingface.co/fofr/flux-80s-cyberpunk/resolve/main/lora.safetensors.",
default=None,
Expand All @@ -136,14 +202,42 @@ def train(
if not input_images:
raise ValueError("input_images must be provided")

sample_prompts = []
if wandb_sample_prompts:
sample_prompts = [p.strip() for p in wandb_sample_prompts.split(";")]

wandb_client = None
if wandb_api_key:
wandb_config = {
"trigger_word": trigger_word,
"autocaption": autocaption,
"autocaption_prefix": autocaption_prefix,
"autocaption_suffix": autocaption_suffix,
"steps": steps,
"learning_rate": learning_rate,
"batch_size": batch_size,
"resolution": resolution,
"lora_rank": lora_rank,
"caption_dropout_rate": caption_dropout_rate,
"optimizer": optimizer,
}
wandb_client = WeightsAndBiasesClient(
api_key=wandb_api_key.get_secret_value(),
config=wandb_config,
sample_prompts=sample_prompts,
project=wandb_project,
entity=wandb_entity,
name=wandb_run,
)

download_weights()
extract_zip(input_images, INPUT_DIR)

train_config = OrderedDict(
{
"job": "custom_job",
"config": {
"name": "flux_train_replicate",
"name": JOB_NAME,
"process": [
{
"type": "custom_sd_trainer",
Expand All @@ -157,7 +251,9 @@ def train(
},
"save": {
"dtype": "float16",
"save_every": steps + 1,
"save_every": wandb_save_interval
if wandb_api_key
else steps + 1,
"max_step_saves_to_keep": 1,
},
"datasets": [
Expand All @@ -166,6 +262,7 @@ def train(
"caption_ext": "txt",
"caption_dropout_rate": caption_dropout_rate,
"shuffle_tokens": False,
# TODO: Do we need to cache to disk? It's faster not to.
"cache_latents_to_disk": True,
"resolution": [
int(res) for res in resolution.split(",")
Expand Down Expand Up @@ -193,15 +290,17 @@ def train(
},
"sample": {
"sampler": "flowmatch",
"sample_every": steps + 1,
"sample_every": wandb_sample_interval
if wandb_api_key and sample_prompts
else steps + 1,
"width": 1024,
"height": 1024,
"prompts": [],
"prompts": sample_prompts,
"neg": "",
"seed": 42,
"walk_seed": True,
"guidance_scale": 4,
"sample_steps": 20,
"guidance_scale": 3.5,
"sample_steps": 28,
},
}
],
Expand All @@ -222,34 +321,47 @@ def train(
torch.cuda.empty_cache()

print("Starting train job")
job = CustomJob(get_config(train_config, name=None))
job = CustomJob(get_config(train_config, name=None), wandb_client)
job.run()

if wandb_client:
wandb_client.finish()

job.cleanup()

lora_dir = OUTPUT_DIR / "flux_train_replicate"
lora_file = lora_dir / "flux_train_replicate.safetensors"
lora_file.rename(lora_dir / "lora.safetensors")
lora_file = JOB_DIR / f"{JOB_NAME}.safetensors"
lora_file.rename(JOB_DIR / "lora.safetensors")

samples_dir = JOB_DIR / "samples"
if samples_dir.exists():
shutil.rmtree(samples_dir)

# Remove any intermediate lora paths
lora_paths = JOB_DIR.glob("*.safetensors")
for path in lora_paths:
if path.name != "lora.safetensors":
path.unlink()

# Optimizer is used to continue training, not needed in output
optimizer_file = lora_dir / "optimizer.pt"
optimizer_file = JOB_DIR / "optimizer.pt"
if optimizer_file.exists():
optimizer_file.unlink()

# Copy generated captions to the output tar
# But do not upload publicly to HF
captions_dir = lora_dir / "captions"
captions_dir = JOB_DIR / "captions"
captions_dir.mkdir(exist_ok=True)
for caption_file in INPUT_DIR.glob("*.txt"):
shutil.copy(caption_file, captions_dir)

os.system(f"tar -cvf {output_path} {lora_dir}")
os.system(f"tar -cvf {output_path} {JOB_DIR}")

if hf_token is not None and hf_repo_id is not None:
if captions_dir.exists():
shutil.rmtree(captions_dir)

try:
handle_hf_readme(lora_dir, hf_repo_id, trigger_word)
handle_hf_readme(hf_repo_id, trigger_word)
print(f"Uploading to Hugging Face: {hf_repo_id}")
api = HfApi()

Expand All @@ -264,7 +376,7 @@ def train(

api.upload_folder(
repo_id=hf_repo_id,
folder_path=lora_dir,
folder_path=JOB_DIR,
repo_type="model",
use_auth_token=hf_token.get_secret_value(),
)
Expand All @@ -274,8 +386,8 @@ def train(
return TrainingOutput(weights=Path(output_path))


def handle_hf_readme(lora_dir: Path, hf_repo_id: str, trigger_word: Optional[str]):
readme_path = lora_dir / "README.md"
def handle_hf_readme(hf_repo_id: str, trigger_word: Optional[str]):
readme_path = JOB_DIR / "README.md"
license_path = Path("lora-license.md")
shutil.copy(license_path, readme_path)

Expand Down
43 changes: 43 additions & 0 deletions wandb_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from pathlib import Path
from typing import Any, Sequence
import wandb
from wandb.sdk.wandb_settings import Settings


class WeightsAndBiasesClient:
    """Small wrapper around the ``wandb`` module for a single training run.

    Creating an instance logs in with the supplied API key and starts a run;
    the other methods forward losses, sample images and weight files to
    ``wandb`` and finally close the run.
    """

    def __init__(
        self,
        api_key: str,
        project: str,
        config: dict,
        sample_prompts: list[str],
        entity: str | None,
        name: str | None,
    ):
        """Log in to W&B and start a run.

        Args:
            api_key: W&B API key used for ``wandb.login`` (verified on login).
            project: Project the run is created in.
            config: Dictionary stored as the run config.
            sample_prompts: Prompts used to label logged sample images.
            entity: Optional W&B entity for the run.
            name: Optional display name for the run.
        """
        self.api_key = api_key
        self.sample_prompts = sample_prompts

        # Authenticate first; the run is only created after a verified login.
        wandb.login(key=self.api_key, verify=True)

        run_settings = Settings(_disable_machine_info=True)
        self.run = wandb.init(
            project=project,
            entity=entity,
            name=name,
            config=config,
            save_code=False,
            settings=run_settings,
        )

    def log_loss(self, loss_dict: dict[str, Any], step: int | None):
        """Log the given loss dictionary at ``step``."""
        wandb.log(data=loss_dict, step=step)

    def log_samples(self, image_paths: Sequence[Path], step: int | None):
        """Log sample images at ``step``, one entry per prompt.

        Prompts and image paths are paired positionally; pairing stops at the
        shorter of the two sequences. Each image is logged under the key
        ``samples/<prompt>``.
        """
        images_by_key = {}
        for prompt, image_path in zip(self.sample_prompts, image_paths):
            images_by_key[f"samples/{prompt}"] = wandb.Image(str(image_path))
        wandb.log(data=images_by_key, step=step)

    def save_weights(self, lora_path: Path):
        """Hand the weights file at ``lora_path`` to ``wandb.save``."""
        wandb.save(lora_path)

    def finish(self):
        """Finish the current W&B run."""
        wandb.finish()

0 comments on commit e69b0f9

Please sign in to comment.