  14 |  14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15 |  15 | # See the License for the specific language governing permissions and
  16 |  16 |
  17 |     | -import contextlib
  18 |  17 | import gc
  19 |  18 | import itertools
  20 |  19 | import json
  26 |  25 | from pathlib import Path
  27 |  26 |
  28 |  27 | import diffusers
  29 |     | -import numpy as np
  30 |  28 | import torch
  31 |  29 | import torch.nn.functional as F
  32 |  30 | import torch.utils.checkpoint
  36 |  34 | from diffusers import (
  37 |  35 |     AutoencoderKL,
  38 |  36 |     DDPMScheduler,
  39 |     | -    DPMSolverMultistepScheduler,
  40 |  37 |     EDMEulerScheduler,
  41 |  38 |     EulerDiscreteScheduler,
  42 |  39 |     StableDiffusionXLPipeline,
@@ -78,59 +75,6 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
  78 |  75 |     return scheduler_type
  79 |  76 |
  80 |  77 |
  81 |     | -def log_validation(
  82 |     | -    pipeline,
  83 |     | -    args,
  84 |     | -    accelerator,
  85 |     | -    pipeline_args,
  86 |     | -    epoch,
  87 |     | -    is_final_validation=False,
  88 |     | -):
  89 |     | -    logger.info(
  90 |     | -        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
  91 |     | -        f" {args.validation_prompt}."
  92 |     | -    )
  93 |     | -
  94 |     | -    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
  95 |     | -    scheduler_args = {}
  96 |     | -
  97 |     | -    if not args.do_edm_style_training:
  98 |     | -        if "variance_type" in pipeline.scheduler.config:
  99 |     | -            variance_type = pipeline.scheduler.config.variance_type
 100 |     | -
 101 |     | -            if variance_type in ["learned", "learned_range"]:
 102 |     | -                variance_type = "fixed_small"
 103 |     | -
 104 |     | -            scheduler_args["variance_type"] = variance_type
 105 |     | -
 106 |     | -        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 107 |     | -
 108 |     | -    pipeline = pipeline.to(accelerator.device)
 109 |     | -    pipeline.set_progress_bar_config(disable=True)
 110 |     | -
 111 |     | -    # run inference
 112 |     | -    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
 113 |     | -    # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
 114 |     | -    # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
 115 |     | -    inference_ctx = (
 116 |     | -        contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
 117 |     | -    )
 118 |     | -
 119 |     | -    with inference_ctx:
 120 |     | -        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
 121 |     | -
 122 |     | -    for tracker in accelerator.trackers:
 123 |     | -        phase_name = "test" if is_final_validation else "validation"
 124 |     | -        if tracker.name == "tensorboard":
 125 |     | -            np_images = np.stack([np.asarray(img) for img in images])
 126 |     | -            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
 127 |     | -
 128 |     | -    del pipeline
 129 |     | -    torch.cuda.empty_cache()
 130 |     | -
 131 |     | -    return images
 132 |     | -
 133 |     | -
 134 |  78 | def import_model_class_from_model_name_or_path(
 135 |  79 |     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 136 |  80 | ):
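The removed `log_validation` helper contains one step that is easy to miss: before sampling with `DPMSolverMultistepScheduler`, a scheduler configured to predict a variance (`learned` / `learned_range`) must be forced back to a fixed one, because training uses the simplified noise-prediction objective. A minimal standalone sketch of that handling (the SDXL checkpoint ID here is an illustrative placeholder, not taken from this diff):

```python
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline

# Placeholder checkpoint; any SDXL-style pipeline works the same way.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Mirror of the removed variance_type handling: if the scheduler was set up
# to predict a variance, fall back to a fixed one before swapping schedulers.
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
    variance_type = pipeline.scheduler.config.variance_type
    if variance_type in ["learned", "learned_range"]:
        variance_type = "fixed_small"
    scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
```

`from_config` reuses the existing scheduler configuration, so only `variance_type` changes; EDM-style runs skip this entirely, as the `if not args.do_edm_style_training` guard above shows.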
@@ -1239,42 +1183,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1239 | 1183 |             if global_step >= args.max_train_steps:
1240 | 1184 |                 break
1241 | 1185 |
1242 |      | -        if accelerator.is_main_process:
1243 |      | -            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1244 |      | -                # create pipeline
1245 |      | -                if not args.train_text_encoder:
1246 |      | -                    text_encoder_one = text_encoder_cls_one.from_pretrained(
1247 |      | -                        args.pretrained_model_name_or_path,
1248 |      | -                        subfolder="text_encoder",
1249 |      | -                        revision=args.revision,
1250 |      | -                        variant=args.variant,
1251 |      | -                    )
1252 |      | -                    text_encoder_two = text_encoder_cls_two.from_pretrained(
1253 |      | -                        args.pretrained_model_name_or_path,
1254 |      | -                        subfolder="text_encoder_2",
1255 |      | -                        revision=args.revision,
1256 |      | -                        variant=args.variant,
1257 |      | -                    )
1258 |      | -                pipeline = StableDiffusionXLPipeline.from_pretrained(
1259 |      | -                    args.pretrained_model_name_or_path,
1260 |      | -                    vae=vae,
1261 |      | -                    text_encoder=accelerator.unwrap_model(text_encoder_one),
1262 |      | -                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1263 |      | -                    unet=accelerator.unwrap_model(unet),
1264 |      | -                    revision=args.revision,
1265 |      | -                    variant=args.variant,
1266 |      | -                    torch_dtype=weight_dtype,
1267 |      | -                )
1268 |      | -                pipeline_args = {"prompt": args.validation_prompt}
1269 |      | -
1270 |      | -                images = log_validation(
1271 |      | -                    pipeline,
1272 |      | -                    args,
1273 |      | -                    accelerator,
1274 |      | -                    pipeline_args,
1275 |      | -                    epoch,
1276 |      | -                )
1277 |      | -
1278 | 1186 |     # Save the lora layers
1279 | 1187 |     accelerator.wait_for_everyone()
1280 | 1188 |     if accelerator.is_main_process: