Skip to content

Commit

Permalink
fix: wandb main proccess fix
Browse files Browse the repository at this point in the history
Update train_flux_lora_deepspeed.py
  • Loading branch information
Vovanm88 authored Sep 23, 2024
1 parent ac4e04b commit 4749542
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions train_flux_lora_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,19 +290,20 @@ def main():
train_loss = 0.0

if not args.disable_sampling and global_step % args.sample_every == 0:
print(f"Sampling images for step {global_step}...")
sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device)
images = []
for i, prompt in enumerate(args.sample_prompts):
result = sampler(prompt=prompt,
width=args.sample_width,
height=args.sample_height,
num_steps=args.sample_steps
)
images.append(wandb.Image(result))
print(f"Result for prompt #{i} is generated")
# result.save(f"{global_step}_prompt_{i}_res.png")
wandb.log({f"Results, step {global_step}": images})
if accelerator.is_main_process:
print(f"Sampling images for step {global_step}...")
sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device)
images = []
for i, prompt in enumerate(args.sample_prompts):
result = sampler(prompt=prompt,
width=args.sample_width,
height=args.sample_height,
num_steps=args.sample_steps
)
images.append(wandb.Image(result))
print(f"Result for prompt #{i} is generated")
# result.save(f"{global_step}_prompt_{i}_res.png")
wandb.log({f"Results, step {global_step}": images})

if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
Expand Down

0 comments on commit 4749542

Please sign in to comment.