
Commit e10241b: try diffusers update (#814)
Parent commit: c271eb7

3 files changed: +12 −98 lines


requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ trl==0.12.0
 tiktoken==0.6.0
 transformers==4.46.2
 accelerate==1.1.1
-diffusers==0.27.2
+diffusers==0.31.0
 bitsandbytes==0.44.1
 # extras
 rouge_score==0.1.2
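
Since the updated train.py imports StableDiffusionLoraLoaderMixin, which the old 0.27.2 pin does not provide, a runtime guard can fail fast in an environment that was not reinstalled after this bump. A minimal sketch (the guard is not part of this commit; the version bound mirrors the pin above):

import importlib.metadata

# Assumes a plain X.Y.Z version string, as pinned in requirements.txt.
installed = importlib.metadata.version("diffusers")
major, minor = (int(part) for part in installed.split(".")[:2])
if (major, minor) < (0, 31):
    raise RuntimeError(f"diffusers>=0.31.0 expected, found {installed}")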

src/autotrain/trainers/dreambooth/train.py

Lines changed: 11 additions & 5 deletions
@@ -37,7 +37,7 @@
     DPMSolverMultistepScheduler,
     UNet2DConditionModel,
 )
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import StableDiffusionLoraLoaderMixin
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
 from diffusers.utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft, is_wandb_available
@@ -63,6 +63,7 @@ def log_validation(
     accelerator,
     pipeline_args,
     epoch,
+    torch_dtype,  # Add torch_dtype parameter
     is_final_validation=False,
 ):
     logger.info(
@@ -82,7 +83,7 @@ def log_validation(
 
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 
-    pipeline = pipeline.to(accelerator.device)
+    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)  # Use torch_dtype
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
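
The new torch_dtype parameter lets the validation pipeline run in the same precision as training; .to(device) alone leaves the pipeline in its load-time dtype. A short sketch of the pattern (the model id and fp16 choice are illustrative, not from this commit):

import torch
from diffusers import StableDiffusionPipeline

weight_dtype = torch.float16  # hypothetical training precision
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Moving and casting in one call keeps the pipeline's weights consistent
# with LoRA layers trained in weight_dtype, avoiding fp32/fp16 mismatches.
pipeline = pipeline.to("cuda", dtype=weight_dtype)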
@@ -340,6 +341,10 @@ def main(args):
         project_config=accelerator_project_config,
     )
 
+    # Add MPS support check
+    if torch.backends.mps.is_available():
+        accelerator.native_amp = False
+
     if args.report_to == "wandb":
         if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
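
The MPS guard turns off Accelerate's native automatic mixed precision on Apple silicon, where torch autocast support is incomplete. Standalone, the same pattern looks like this (Accelerator arguments omitted for brevity):

import torch
from accelerate import Accelerator

accelerator = Accelerator()  # project/logging args omitted

# Same guard as the commit: autocast on the MPS backend is not fully
# supported, so native AMP is disabled whenever an Apple GPU is present.
if torch.backends.mps.is_available():
    accelerator.native_amp = False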
@@ -545,7 +550,7 @@ def save_model_hook(models, weights, output_dir):
             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()
 
-        LoraLoaderMixin.save_lora_weights(
+        StableDiffusionLoraLoaderMixin.save_lora_weights(
             output_dir,
             unet_lora_layers=unet_lora_layers_to_save,
             text_encoder_lora_layers=text_encoder_lora_layers_to_save,
@@ -565,7 +570,7 @@ def load_model_hook(models, input_dir):
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
 
        unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
@@ -948,6 +953,7 @@ def compute_text_embeddings(prompt):
                    accelerator,
                    pipeline_args,
                    epoch,
+                   torch_dtype=weight_dtype,
                )
 
    # Save the lora layers
@@ -964,7 +970,7 @@ def compute_text_embeddings(prompt):
        else:
            text_encoder_state_dict = None
 
-       LoraLoaderMixin.save_lora_weights(
+       StableDiffusionLoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_state_dict,
            text_encoder_lora_layers=text_encoder_state_dict,
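
Every LoraLoaderMixin call site becomes StableDiffusionLoraLoaderMixin, matching the rename in newer diffusers releases; both classes expose the same save_lora_weights and lora_state_dict entry points used here. If a script had to tolerate both pins, a guarded import is one option (a sketch, not part of this commit):

try:
    # Name shipped with the new diffusers pin, as imported in train.py.
    from diffusers.loaders import StableDiffusionLoraLoaderMixin as LoraIO
except ImportError:
    # Pre-rename class available in older releases such as 0.27.2.
    from diffusers.loaders import LoraLoaderMixin as LoraIO

# Either alias exposes the classmethods used above, e.g.
# LoraIO.save_lora_weights(...) and LoraIO.lora_state_dict(...).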

src/autotrain/trainers/dreambooth/train_xl.py

Lines changed: 0 additions & 92 deletions
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 
-import contextlib
 import gc
 import itertools
 import json
@@ -26,7 +25,6 @@
 from pathlib import Path
 
 import diffusers
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -36,7 +34,6 @@
 from diffusers import (
     AutoencoderKL,
     DDPMScheduler,
-    DPMSolverMultistepScheduler,
     EDMEulerScheduler,
     EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
@@ -78,59 +75,6 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
     return scheduler_type
 
 
-def log_validation(
-    pipeline,
-    args,
-    accelerator,
-    pipeline_args,
-    epoch,
-    is_final_validation=False,
-):
-    logger.info(
-        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-        f" {args.validation_prompt}."
-    )
-
-    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-    scheduler_args = {}
-
-    if not args.do_edm_style_training:
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
-
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
-    pipeline = pipeline.to(accelerator.device)
-    pipeline.set_progress_bar_config(disable=True)
-
-    # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
-    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
-    inference_ctx = (
-        contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
-    )
-
-    with inference_ctx:
-        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
-
-    for tracker in accelerator.trackers:
-        phase_name = "test" if is_final_validation else "validation"
-        if tracker.name == "tensorboard":
-            np_images = np.stack([np.asarray(img) for img in images])
-            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
-
-    del pipeline
-    torch.cuda.empty_cache()
-
-    return images
-
-
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
@@ -1239,42 +1183,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if global_step >= args.max_train_steps:
                     break
 
-        if accelerator.is_main_process:
-            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                # create pipeline
-                if not args.train_text_encoder:
-                    text_encoder_one = text_encoder_cls_one.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        subfolder="text_encoder",
-                        revision=args.revision,
-                        variant=args.variant,
-                    )
-                    text_encoder_two = text_encoder_cls_two.from_pretrained(
-                        args.pretrained_model_name_or_path,
-                        subfolder="text_encoder_2",
-                        revision=args.revision,
-                        variant=args.variant,
-                    )
-                pipeline = StableDiffusionXLPipeline.from_pretrained(
-                    args.pretrained_model_name_or_path,
-                    vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    unet=accelerator.unwrap_model(unet),
-                    revision=args.revision,
-                    variant=args.variant,
-                    torch_dtype=weight_dtype,
-                )
-                pipeline_args = {"prompt": args.validation_prompt}
-
-                images = log_validation(
-                    pipeline,
-                    args,
-                    accelerator,
-                    pipeline_args,
-                    epoch,
-                )
-
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process: