diff --git a/train_flux_lora_deepspeed.py b/train_flux_lora_deepspeed.py index 6290b50..f6a09a4 100644 --- a/train_flux_lora_deepspeed.py +++ b/train_flux_lora_deepspeed.py @@ -290,19 +290,20 @@ def main(): train_loss = 0.0 if not args.disable_sampling and global_step % args.sample_every == 0: - print(f"Sampling images for step {global_step}...") - sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device) - images = [] - for i, prompt in enumerate(args.sample_prompts): - result = sampler(prompt=prompt, - width=args.sample_width, - height=args.sample_height, - num_steps=args.sample_steps - ) - images.append(wandb.Image(result)) - print(f"Result for prompt #{i} is generated") - # result.save(f"{global_step}_prompt_{i}_res.png") - wandb.log({f"Results, step {global_step}": images}) + if accelerator.is_main_process: + print(f"Sampling images for step {global_step}...") + sampler = XFluxSampler(clip=clip, t5=t5, ae=vae, model=dit, device=accelerator.device) + images = [] + for i, prompt in enumerate(args.sample_prompts): + result = sampler(prompt=prompt, + width=args.sample_width, + height=args.sample_height, + num_steps=args.sample_steps + ) + images.append(wandb.Image(result)) + print(f"Result for prompt #{i} is generated") + # result.save(f"{global_step}_prompt_{i}_res.png") + wandb.log({f"Results, step {global_step}": images}) if global_step % args.checkpointing_steps == 0: if accelerator.is_main_process: