|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from comfy.samplers import ksampler |
| 4 | +from pytorch_wavelets import DTCWTForward, DTCWTInverse, DWTForward, DWTInverse |
| 5 | + |
| 6 | +from .tensor_image_ops import ( |
| 7 | + BLENDING_MODES, |
| 8 | + Sharpen, |
| 9 | +) |
| 10 | +from .upscale import Upscale |
| 11 | +from .utils import fallback |
| 12 | +from .vae import VAEHelper |
| 13 | + |
| 14 | + |
class Config:
    """Bundle of settings plus the helper objects built from them.

    Plain scalar options are stored as same-named attributes. Options that
    configure a helper are handed to that helper, which is stored instead:

    * ``vae``      -- ``VAEHelper`` (``vae_mode``, ``vae``, ``vae_*_kwargs``)
    * ``sharpen``  -- ``Sharpen`` (``sharpen_*``)
    * ``upscale``  -- ``Upscale`` (``resample_mode``, ``rescale_increment``,
      ``upscale_model``)
    * ``dwt`` / ``idwt`` -- forward/inverse wavelet transforms, either the
      DTCWT or the plain DWT pair depending on ``dtcwt_mode``

    ``iteration_override`` maps a key to a dict of field overrides; each entry
    is expanded into a complete ``Config`` that inherits every other setting
    from this instance (see ``get_iteration_config``).
    """

    # Names that an ``iteration_override`` entry is allowed to change. Any
    # other key found in an override dict is silently dropped.
    _overridable_fields = {  # noqa: RUF012
        "blend_by_mode",
        "blend_mode",
        "denoised_wavelet_multiplier",
        "dtcwt_biort",
        "dtcwt_mode",
        "dtcwt_qshift",
        "dwt_flip_filters",
        "dwt_level",
        "dwt_mode",
        "dwt_wave",
        "fadeout_factor",
        "guidance_factor",
        "guidance_mode",
        "guidance_restart_s_noise",
        "guidance_restart",
        "guidance_steps",
        "iteration_override",
        "reference_wavelet_multiplier",
        "renoise_factor",
        "resample_mode",
        "rescale_increment",
        "scale_factor",
        "sharpen_gaussian_kernel_size",
        "sharpen_gaussian_sigma",
        "sharpen_mode",
        "sharpen_reference",
        "sharpen_strength",
        "sigma_offset",
        "vae_decode_kwargs",
        "vae_encode_kwargs",
        "vae_mode",
    }

    # Public attributes that ``as_dict`` must not copy verbatim: methods,
    # derived helper objects (re-expressed as constructor keywords where
    # needed) and the override table itself.
    _dict_exclude_keys = {  # noqa: RUF012
        "as_dict",
        "blend_function",
        "dwt",
        "get_iteration_config",
        "idwt",
        "iteration_override",
        "sharpen",
        "upscale",
        "vae",
    }

    def __init__(
        self,
        device,
        dtype,
        latent_format,
        *,
        blend_mode="lerp",
        blend_by_mode="image",
        denoised_wavelet_multiplier=1.0,
        dtcwt_biort="near_sym_a",
        dtcwt_mode=False,
        dtcwt_qshift="qshift_a",
        dwt_flip_filters=False,
        dwt_level=1,
        dwt_mode="symmetric",
        dwt_wave="db4",
        fadeout_factor=0.0,
        guidance_factor=1.0,
        guidance_mode="image",
        guidance_restart_s_noise=1.0,
        guidance_restart=0,
        guidance_sampler=None,
        guidance_steps=5,
        iteration_override=None,
        iterations=1,
        reference_sampler=None,
        reference_wavelet_multiplier=1.0,
        renoise_factor=1.0,
        resample_mode="bicubic",
        rescale_increment=64,
        sampler=None,
        scale_factor=2.0,
        sharpen_gaussian_kernel_size=3,
        sharpen_gaussian_sigma=(0.1, 2.0),
        sharpen_mode="gaussian",
        sharpen_reference=True,
        sharpen_strength=1.0,
        sigma_offset=0,
        upscale_model=None,
        vae_decode_kwargs=None,
        vae_encode_kwargs=None,
        vae_mode="normal",
        vae=None,
    ):
        """Store the settings and build the helper objects.

        ``device``, ``dtype`` and ``latent_format`` are forwarded to
        ``VAEHelper``; ``device`` is also where the wavelet modules are moved.
        They are not kept as attributes, but they are reused to construct the
        per-iteration override configs.

        Raises:
            ValueError: ``blend_by_mode`` is not ``"image"``, ``"latent"`` or
                ``"wavelet"``.
            KeyError: ``blend_mode`` is not a key of ``BLENDING_MODES``.
            TypeError: ``iteration_override`` is not a dict, or one of its
                items is not an ``int``/``str`` key mapped to a dict.
        """
        # The default sampler is wrapped in a lambda so it is only built when
        # needed (presumably when ``sampler`` is None; ``fallback`` is defined
        # elsewhere -- confirm its exact semantics there).
        sampler = fallback(
            sampler,
            lambda: ksampler("euler"),
            default_is_fun=True,
        )
        self.sigma_offset = sigma_offset
        self.fadeout_factor = fadeout_factor
        self.scale_factor = scale_factor
        self.guidance_factor = guidance_factor
        self.renoise_factor = renoise_factor
        self.iterations = iterations
        self.guidance_steps = guidance_steps
        self.guidance_mode = guidance_mode
        self.guidance_restart = guidance_restart
        self.guidance_restart_s_noise = guidance_restart_s_noise
        self.sampler = sampler
        # The guidance and reference samplers default to the main sampler.
        self.guidance_sampler = fallback(guidance_sampler, sampler)
        self.reference_sampler = fallback(reference_sampler, sampler)
        self.vae = VAEHelper(
            vae_mode,
            latent_format,
            device=device,
            dtype=dtype,
            vae=vae,
            encode_kwargs=fallback(vae_encode_kwargs, {}),
            decode_kwargs=fallback(vae_decode_kwargs, {}),
        )
        # Keep the requested mode on the config itself: ``as_dict`` does not
        # read it back from the ``Sharpen`` helper, so without this attribute
        # every per-iteration override would silently revert to the default
        # ``sharpen_mode``.
        self.sharpen_mode = sharpen_mode
        self.sharpen = Sharpen(
            mode=sharpen_mode,
            # Disabling reference sharpening is expressed as zero strength.
            strength=sharpen_strength if sharpen_reference else 0,
            gaussian_kernel_size=sharpen_gaussian_kernel_size,
            gaussian_sigma=sharpen_gaussian_sigma,
        )
        self.upscale = Upscale(
            resample_mode=resample_mode,
            rescale_increment=rescale_increment,
            upscale_model=upscale_model,
        )
        self.dwt_mode = dwt_mode
        self.dwt_level = dwt_level
        self.dwt_wave = dwt_wave
        self.dtcwt_mode = dtcwt_mode
        self.dtcwt_biort = dtcwt_biort
        self.dtcwt_qshift = dtcwt_qshift
        if dtcwt_mode:
            self.dwt = DTCWTForward(
                J=dwt_level,
                mode=dwt_mode,
                biort=dtcwt_biort,
                qshift=dtcwt_qshift,
            ).to(device)
            self.idwt = DTCWTInverse(
                mode=dwt_mode,
                biort=dtcwt_biort,
                qshift=dtcwt_qshift,
            ).to(device)
        else:
            self.dwt = DWTForward(J=dwt_level, wave=dwt_wave, mode=dwt_mode).to(device)
            self.idwt = DWTInverse(wave=dwt_wave, mode=dwt_mode).to(device)
        self.dwt_flip_filters = dwt_flip_filters
        self.reference_wavelet_multiplier = reference_wavelet_multiplier
        self.denoised_wavelet_multiplier = denoised_wavelet_multiplier
        self.blend_mode = blend_mode
        if blend_by_mode not in {"image", "latent", "wavelet"}:
            raise ValueError("Bad blend_by_mode: must be one of image, latent, wavelet")
        self.blend_by_mode = blend_by_mode
        self.blend_function = BLENDING_MODES[blend_mode]
        # Must exist before ``as_dict`` / ``get_iteration_config`` are used,
        # even when no overrides were supplied.
        self.iteration_override = {}
        if iteration_override is None or iteration_override == {}:
            return
        if not isinstance(iteration_override, dict):
            raise TypeError("Iteration override must be an object")
        # Snapshot of this instance's own settings, used as the base that
        # each override entry is layered on top of.
        selfdict = self.as_dict()
        overrides = self.iteration_override
        for k, v in iteration_override.items():
            if not isinstance(k, (int, str)) or not isinstance(v, dict):
                raise TypeError(
                    "Bad type for override item: key must be integer or string, value must be an object",
                )
            # Only whitelisted fields may be overridden; everything else is
            # inherited from this instance.
            okwargs = selfdict | {
                ok: ov for ok, ov in v.items() if ok in self._overridable_fields
            }
            overrides[k] = self.__class__(device, dtype, latent_format, **okwargs)

    def as_dict(self) -> dict:
        """Return the settings as keyword arguments for ``__init__``.

        Every public attribute not listed in ``_dict_exclude_keys`` is copied
        as-is; settings that live inside the ``vae``, ``sharpen`` and
        ``upscale`` helpers are read back from those helpers. The
        ``iteration_override`` table is intentionally left out.
        """
        result = {
            k: getattr(self, k)
            for k in dir(self)
            if not k.startswith("_") and k not in self._dict_exclude_keys
        }
        result["vae_mode"] = self.vae.mode.name.lower()
        result["vae"] = self.vae.vae
        result["vae_encode_kwargs"] = self.vae.encode_kwargs
        result["vae_decode_kwargs"] = self.vae.decode_kwargs
        # NOTE(review): when the config was built with sharpen_reference=False
        # the helper holds strength 0, so the originally requested
        # sharpen_strength is reported as 0 here.
        result["sharpen_reference"] = self.sharpen.strength != 0
        result["sharpen_strength"] = self.sharpen.strength
        result["sharpen_gaussian_kernel_size"] = self.sharpen.gaussian_kernel_size
        result["sharpen_gaussian_sigma"] = self.sharpen.gaussian_sigma
        result["resample_mode"] = self.upscale.resample_mode
        result["rescale_increment"] = self.upscale.rescale_increment
        result["upscale_model"] = self.upscale.upscale_model
        return result

    def get_iteration_config(self, iteration):
        """Return the config to use for ``iteration``.

        Looks ``iteration`` up in ``iteration_override`` exactly as given (an
        ``int`` does not match a string key such as ``"1"``). A matching
        override is asked in turn, so nested overrides are resolved; without a
        match this instance is returned.
        """
        override = self.iteration_override.get(iteration)
        return override.get_iteration_config(iteration) if override else self
0 commit comments