diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index b6d0399108d..be2c7d70fa3 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -47,6 +47,7 @@
     pack,
     unpack,
 )
+from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
 from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
 from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -63,7 +64,7 @@
     title="FLUX Denoise",
     tags=["image", "flux"],
     category="image",
-    version="4.1.0",
+    version="4.2.0",
 )
 class FluxDenoiseInvocation(BaseInvocation):
     """Run denoising process with a FLUX transformer model."""
@@ -132,6 +133,12 @@ class FluxDenoiseInvocation(BaseInvocation):
     num_steps: int = InputField(
         default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
     )
+    scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
+        default="euler",
+        description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
+        "'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
+        ui_choice_labels=FLUX_SCHEDULER_LABELS,
+    )
     guidance: float = InputField(
         default=4.0,
         description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
@@ -242,6 +249,12 @@ def _run_diffusion(
             shift=not is_schnell,
         )

+        # Resolve the requested scheduler. Names missing from the map (e.g. 'lcm' on older
+        # diffusers versions) leave this as None, which selects the built-in Euler loop.
+        scheduler = None
+        if self.scheduler in FLUX_SCHEDULER_MAP:
+            scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
+            scheduler = scheduler_class(num_train_timesteps=1000)
+
         # Clip the timesteps schedule based on denoising_start and denoising_end.
         timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
@@ -426,6 +439,7 @@ def _run_diffusion(
             img_cond=img_cond,
             img_cond_seq=img_cond_seq,
             img_cond_seq_ids=img_cond_seq_ids,
+            scheduler=scheduler,
         )

         x = unpack(x.float(), self.height, self.width)
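A minimal sketch of the name-to-class resolution this hunk adds, assuming only that `FLUX_SCHEDULER_MAP` maps scheduler names to diffusers classes as defined in `invokeai/backend/flux/schedulers.py` later in this diff (the inline map and `resolve_scheduler` helper below are illustrative, not the actual invocation code):

```python
# Illustrative sketch of the scheduler-resolution behavior above, not the real
# invocation code. Names absent from the map resolve to None, which denoise()
# treats as "use the built-in Euler loop".
from diffusers import FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler

FLUX_SCHEDULER_MAP = {
    "euler": FlowMatchEulerDiscreteScheduler,
    "heun": FlowMatchHeunDiscreteScheduler,
}


def resolve_scheduler(name: str):
    scheduler_class = FLUX_SCHEDULER_MAP.get(name)
    # num_train_timesteps=1000 matches the instantiation in the hunk above.
    return scheduler_class(num_train_timesteps=1000) if scheduler_class else None


assert type(resolve_scheduler("heun")) is FlowMatchHeunDiscreteScheduler
assert resolve_scheduler("unknown") is None
```

Note that "euler" is itself a key in the map, so even the default takes the new scheduler path in `denoise()` with a diffusers `FlowMatchEulerDiscreteScheduler`; only names missing from the map (e.g. "lcm" on a diffusers build without `FlowMatchLCMScheduler`) leave `scheduler` as `None` and fall through to the built-in loop.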
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index 5c6443849c3..34839434382 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -1,7 +1,9 @@
+import inspect
 import math
 from typing import Callable

 import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from tqdm import tqdm

 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
@@ -35,24 +37,199 @@ def denoise(
     # extra img tokens (sequence-wise) - for Kontext conditioning
     img_cond_seq: torch.Tensor | None = None,
     img_cond_seq_ids: torch.Tensor | None = None,
+    # Optional scheduler for alternative sampling methods
+    scheduler: SchedulerMixin | None = None,
 ):
-    # step 0 is the initial state
-    total_steps = len(timesteps) - 1
-    step_callback(
-        PipelineIntermediateState(
-            step=0,
-            order=1,
-            total_steps=total_steps,
-            timestep=int(timesteps[0]),
-            latents=img,
-        ),
-    )
+    # Determine if we're using a diffusers scheduler or the built-in Euler method
+    use_scheduler = scheduler is not None
+
+    if use_scheduler:
+        # Initialize scheduler with timesteps.
+        # The timesteps list contains values in the [0, 1] range (sigmas).
+        # Some schedulers (like Euler) support custom sigmas; others (like Heun) don't.
+        set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
+        if "sigmas" in set_timesteps_sig.parameters:
+            # Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
+            scheduler.set_timesteps(sigmas=timesteps, device=img.device)
+        else:
+            # Scheduler doesn't support custom sigmas - use num_inference_steps.
+            # The schedule will be computed by the scheduler itself.
+            num_inference_steps = len(timesteps) - 1
+            scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)
+
+        # For schedulers like Heun, the number of actual steps may differ
+        # (Heun doubles timesteps internally).
+        num_scheduler_steps = len(scheduler.timesteps)
+        # For the user-facing step count, use the original number of denoising steps.
+        total_steps = len(timesteps) - 1
+    else:
+        total_steps = len(timesteps) - 1
+        num_scheduler_steps = total_steps

     # guidance_vec is ignored for schnell.
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

     # Store original sequence length for slicing predictions
     original_seq_len = img.shape[1]

+    # Track the actual step for user-facing progress (accounts for Heun's double steps)
+    user_step = 0
+
+    if use_scheduler:
+        # Use diffusers scheduler for stepping
+        for step_index in tqdm(range(num_scheduler_steps)):
+            timestep = scheduler.timesteps[step_index]
+            # Convert scheduler timestep (0-1000) to normalized (0-1) for the model
+            t_curr = timestep.item() / scheduler.config.num_train_timesteps
+            t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+            # For the Heun scheduler, track whether we're in the first- or second-order step
+            is_heun = hasattr(scheduler, "state_in_first_order")
+            in_first_order = scheduler.state_in_first_order if is_heun else True
+
+            # Run ControlNet models
+            controlnet_residuals: list[ControlNetFluxOutput] = []
+            for controlnet_extension in controlnet_extensions:
+                controlnet_residuals.append(
+                    controlnet_extension.run_controlnet(
+                        timestep_index=user_step,
+                        total_num_timesteps=total_steps,
+                        img=img,
+                        img_ids=img_ids,
+                        txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                        txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                        y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                        timesteps=t_vec,
+                        guidance=guidance_vec,
+                    )
+                )
+
+            merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
+            # Prepare input for model
+            img_input = img
+            img_input_ids = img_ids
+
+            if img_cond is not None:
+                img_input = torch.cat((img_input, img_cond), dim=-1)
+
+            if img_cond_seq is not None:
+                assert img_cond_seq_ids is not None
+                img_input = torch.cat((img_input, img_cond_seq), dim=1)
+                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
+
+            pred = model(
+                img=img_input,
+                img_ids=img_input_ids,
+                txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                timesteps=t_vec,
+                guidance=guidance_vec,
+                timestep_index=user_step,
+                total_num_timesteps=total_steps,
+                controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+                controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
+                ip_adapter_extensions=pos_ip_adapter_extensions,
+                regional_prompting_extension=pos_regional_prompting_extension,
+            )
+
+            if img_cond_seq is not None:
+                pred = pred[:, :original_seq_len]
+
+            # Get CFG scale for current user step
+            step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]
+
+            if not math.isclose(step_cfg_scale, 1.0):
+                if neg_regional_prompting_extension is None:
+                    raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
+
+                neg_img_input = img
+                neg_img_input_ids = img_ids
+
+                if img_cond is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)
+
+                if img_cond_seq is not None:
+                    neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
+                    neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)
+
+                neg_pred = model(
+                    img=neg_img_input,
+                    img_ids=neg_img_input_ids,
+                    txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                    timestep_index=user_step,
+                    total_num_timesteps=total_steps,
+                    controlnet_double_block_residuals=None,
+                    controlnet_single_block_residuals=None,
+                    ip_adapter_extensions=neg_ip_adapter_extensions,
+                    regional_prompting_extension=neg_regional_prompting_extension,
+                )
+
+                if img_cond_seq is not None:
+                    neg_pred = neg_pred[:, :original_seq_len]
+                pred = neg_pred + step_cfg_scale * (pred - neg_pred)
+
+            # Use scheduler.step() for the update
+            step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
+            img = step_output.prev_sample
+
+            # Get t_prev for inpainting (next sigma value)
+            if step_index + 1 < len(scheduler.sigmas):
+                t_prev = scheduler.sigmas[step_index + 1].item()
+            else:
+                t_prev = 0.0
+
+            if inpaint_extension is not None:
+                img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)
+
+            # For Heun, only increment user_step after the second-order step completes
+            if is_heun:
+                if not in_first_order:
+                    # Second-order step completed
+                    user_step += 1
+                    # Only call step_callback if we haven't exceeded total_steps
+                    if user_step <= total_steps:
+                        preview_img = img - t_curr * pred
+                        if inpaint_extension is not None:
+                            preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
+                                preview_img, 0.0
+                            )
+                        step_callback(
+                            PipelineIntermediateState(
+                                step=user_step,
+                                order=2,
+                                total_steps=total_steps,
+                                timestep=int(t_curr * 1000),
+                                latents=preview_img,
+                            ),
+                        )
+            else:
+                # For Euler, LCM, and other first-order schedulers
+                user_step += 1
+                # Only call step_callback if we haven't exceeded total_steps
+                # (the LCM scheduler may have more internal steps than user-facing steps)
+                if user_step <= total_steps:
+                    preview_img = img - t_curr * pred
+                    if inpaint_extension is not None:
+                        preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
+                    step_callback(
+                        PipelineIntermediateState(
+                            step=user_step,
+                            order=1,
+                            total_steps=total_steps,
+                            timestep=int(t_curr * 1000),
+                            latents=preview_img,
+                        ),
+                    )
+
+        return img
+
+    # Original Euler implementation (when scheduler is None)
     for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
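The scheduler branch above is long, so here is a compressed, hypothetical stand-in for it (no ControlNet, CFG, or inpainting) that isolates its three non-obvious pieces: the `set_timesteps` capability sniff, the timestep normalization, and the clean-image preview. The `model` callable and its signature are placeholders. `FlowMatchHeunDiscreteScheduler.set_timesteps` accepts no `sigmas` argument and roughly doubles its internal timesteps, which is why the real code tracks `user_step` separately from `step_index`.

```python
import inspect

import torch
from diffusers import FlowMatchEulerDiscreteScheduler


def denoise_sketch(model, img: torch.Tensor, timesteps: list[float], scheduler=None):
    """Compressed sketch of the scheduler path above; `model` returns a flow-matching velocity."""
    scheduler = scheduler or FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    # Capability sniff: schedulers whose set_timesteps() accepts `sigmas` can consume
    # InvokeAI's time-shifted schedule directly; the rest recompute their own schedule.
    if "sigmas" in inspect.signature(scheduler.set_timesteps).parameters:
        scheduler.set_timesteps(sigmas=timesteps, device=img.device)
    else:
        scheduler.set_timesteps(num_inference_steps=len(timesteps) - 1, device=img.device)

    for timestep in scheduler.timesteps:
        # diffusers timesteps live in [0, 1000]; the FLUX transformer expects [0, 1].
        t_curr = timestep.item() / scheduler.config.num_train_timesteps
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        pred = model(img, t_vec)  # velocity: v ≈ noise - x0

        # Preview math used above: with x_t = (1 - t) * x0 + t * noise and
        # v = noise - x0, the clean-image estimate is x0 ≈ x_t - t * v.
        _preview = img - t_curr * pred

        img = scheduler.step(model_output=pred, timestep=timestep, sample=img).prev_sample
    return img
```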
+""" + +from typing import Literal, Type + +from diffusers import ( + FlowMatchEulerDiscreteScheduler, + FlowMatchHeunDiscreteScheduler, +) +from diffusers.schedulers.scheduling_utils import SchedulerMixin + +# Note: FlowMatchLCMScheduler may not be available in all diffusers versions +try: + from diffusers import FlowMatchLCMScheduler + + _HAS_LCM = True +except ImportError: + _HAS_LCM = False + +# Scheduler name literal type for type checking +FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"] + +# Human-readable labels for the UI +FLUX_SCHEDULER_LABELS: dict[str, str] = { + "euler": "Euler", + "heun": "Heun (2nd order)", + "lcm": "LCM", +} + +# Mapping from scheduler names to scheduler classes +FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = { + "euler": FlowMatchEulerDiscreteScheduler, + "heun": FlowMatchHeunDiscreteScheduler, +} + +if _HAS_LCM: + FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 93b97618fed..b5c5f7b51af 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -69,6 +69,9 @@ const slice = createSlice({ setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, + setFluxScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => { + state.fluxScheduler = action.payload; + }, setUpscaleScheduler: (state, action: PayloadAction) => { state.upscaleScheduler = action.payload; }, @@ -449,6 +452,7 @@ export const { setCfgRescaleMultiplier, setGuidance, setScheduler, + setFluxScheduler, setUpscaleScheduler, setUpscaleCfgScale, setSeed, @@ -588,6 +592,7 @@ export const selectModelSupportsOptimizedDenoising = createSelector( (model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base) ); export const selectScheduler = createParamsSelector((params) => params.scheduler); +export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler); export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis); export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis); export const selectSeed = createParamsSelector((params) => params.seed); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 0b714efee65..036e9e54e6a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -9,6 +9,7 @@ import { zParameterCLIPGEmbedModel, zParameterCLIPLEmbedModel, zParameterControlLoRAModel, + zParameterFluxScheduler, zParameterGuidance, zParameterImageDimension, zParameterMaskBlurMethod, @@ -596,6 +597,7 @@ export const zParamsState = z.object({ optimizedDenoisingEnabled: z.boolean(), iterations: z.number(), scheduler: zParameterScheduler, + fluxScheduler: zParameterFluxScheduler, upscaleScheduler: zParameterScheduler, upscaleCfgScale: zParameterCFGScale, seed: zParameterSeed, @@ -650,6 +652,7 @@ export const getInitialParamsState = (): ParamsState => ({ optimizedDenoisingEnabled: true, iterations: 1, scheduler: 'dpmpp_3m_k', + fluxScheduler: 'euler', upscaleScheduler: 'kdpm_2', upscaleCfgScale: 2, seed: 0, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts 
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts
index 93b97618fed..b5c5f7b51af 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts
@@ -69,6 +69,9 @@ const slice = createSlice({
     setScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
       state.scheduler = action.payload;
     },
+    setFluxScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
+      state.fluxScheduler = action.payload;
+    },
     setUpscaleScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
       state.upscaleScheduler = action.payload;
     },
@@ -449,6 +452,7 @@
   setCfgRescaleMultiplier,
   setGuidance,
   setScheduler,
+  setFluxScheduler,
   setUpscaleScheduler,
   setUpscaleCfgScale,
   setSeed,
@@ -588,6 +592,7 @@ export const selectModelSupportsOptimizedDenoising = createSelector(
   (model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
 );
 export const selectScheduler = createParamsSelector((params) => params.scheduler);
+export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler);
 export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis);
 export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis);
 export const selectSeed = createParamsSelector((params) => params.seed);
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
index 0b714efee65..036e9e54e6a 100644
--- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts
@@ -9,6 +9,7 @@ import {
   zParameterCLIPGEmbedModel,
   zParameterCLIPLEmbedModel,
   zParameterControlLoRAModel,
+  zParameterFluxScheduler,
   zParameterGuidance,
   zParameterImageDimension,
   zParameterMaskBlurMethod,
@@ -596,6 +597,7 @@ export const zParamsState = z.object({
   optimizedDenoisingEnabled: z.boolean(),
   iterations: z.number(),
   scheduler: zParameterScheduler,
+  fluxScheduler: zParameterFluxScheduler,
   upscaleScheduler: zParameterScheduler,
   upscaleCfgScale: zParameterCFGScale,
   seed: zParameterSeed,
@@ -650,6 +652,7 @@ export const getInitialParamsState = (): ParamsState => ({
   optimizedDenoisingEnabled: true,
   iterations: 1,
   scheduler: 'dpmpp_3m_k',
+  fluxScheduler: 'euler',
   upscaleScheduler: 'kdpm_2',
   upscaleCfgScale: 2,
   seed: 0,
diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts
index 60d0754a147..438485481d7 100644
--- a/invokeai/frontend/web/src/features/nodes/types/common.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/common.ts
@@ -64,6 +64,9 @@ export const zSchedulerField = z.enum([
   'tcd',
 ]);
 export type SchedulerField = z.infer<typeof zSchedulerField>;
+
+// Flux-specific scheduler options (Flow Matching schedulers)
+export const zFluxSchedulerField = z.enum(['euler', 'heun', 'lcm']);
 // #endregion

 // #region Model-related schemas
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
index 558f8b2ffed..93d1ca9446f 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
@@ -43,7 +43,7 @@
 export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderReturn> => {
diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamFluxScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamFluxScheduler.tsx
new file mode 100644
--- /dev/null
+++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamFluxScheduler.tsx
+import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
+import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
+import { selectFluxScheduler, setFluxScheduler } from 'features/controlLayers/store/paramsSlice';
+import { isParameterFluxScheduler } from 'features/parameters/types/parameterSchemas';
+import { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+
+const FLUX_SCHEDULER_OPTIONS: ComboboxOption[] = [
+  { value: 'euler', label: 'Euler' },
+  { value: 'heun', label: 'Heun (2nd order)' },
+  { value: 'lcm', label: 'LCM' },
+];
+
+const ParamFluxScheduler = () => {
+  const dispatch = useAppDispatch();
+  const { t } = useTranslation();
+  const fluxScheduler = useAppSelector(selectFluxScheduler);
+
+  const onChange = useCallback<ComboboxOnChange>(
+    (v) => {
+      if (!isParameterFluxScheduler(v?.value)) {
+        return;
+      }
+      dispatch(setFluxScheduler(v.value));
+    },
+    [dispatch]
+  );
+
+  const value = useMemo(() => FLUX_SCHEDULER_OPTIONS.find((o) => o.value === fluxScheduler), [fluxScheduler]);
+
+  return (
+    <FormControl>
+      <InformationalPopover feature="paramScheduler">
+        <FormLabel>{t('parameters.scheduler')}</FormLabel>
+      </InformationalPopover>
+      <Combobox value={value} options={FLUX_SCHEDULER_OPTIONS} onChange={onChange} />
+    </FormControl>
+  );
+};
+
+export default memo(ParamFluxScheduler);
diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
index 01ec9d97798..0eaa517a4a0 100644
--- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
+++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts
@@ -1,7 +1,7 @@
 import { NUMPY_RAND_MAX } from 'app/constants';
 import { roundToMultiple } from 'common/util/roundDownToMultiple';
 import { buildZodTypeGuard } from 'common/util/zodUtils';
-import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
+import { zFluxSchedulerField, zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common';
 import { z } from 'zod';

 /**
@@ -61,6 +61,11 @@
 export const [zParameterScheduler, isParameterScheduler] = buildParameter(zSchedulerField);
 export type ParameterScheduler = z.infer<typeof zParameterScheduler>;
 // #endregion
+
+// #region Flux Scheduler
+export const [zParameterFluxScheduler, isParameterFluxScheduler] = buildParameter(zFluxSchedulerField);
+export type ParameterFluxScheduler = z.infer<typeof zParameterFluxScheduler>;
+// #endregion
+
 // #region seed
 export const [zParameterSeed, isParameterSeed] = buildParameter(z.number().int().min(0).max(NUMPY_RAND_MAX));
 export type ParameterSeed = z.infer<typeof zParameterSeed>;
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx
index 0b55f5db096..283242f1988 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion.tsx
@@ -8,6 +8,7 @@
 import { selectIsCogView4, selectIsFLUX, selectIsSD3, selectIsZImage } from 'features/controlLayers/store/paramsSlice';
 import { LoRAList } from 'features/lora/components/LoRAList';
 import LoRASelect from 'features/lora/components/LoRASelect';
 import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
+import ParamFluxScheduler from 'features/parameters/components/Core/ParamFluxScheduler';
 import ParamGuidance from 'features/parameters/components/Core/ParamGuidance';
 import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
 import ParamSteps from 'features/parameters/components/Core/ParamSteps';
@@ -68,6 +69,7 @@ export const GenerationSettingsAccordion = memo(() => {
           {!isFLUX && !isSD3 && !isCogView4 && !isZImage && <ParamScheduler />}
+          {isFLUX && <ParamFluxScheduler />}
           {isFLUX && modelConfig && !isFluxFillMainModelModelConfig(modelConfig) && <ParamGuidance />}
           {!isFLUX && <ParamCFGScale />}
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index fd3f699aa71..b2e4fccb3bf 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -8323,6 +8323,13 @@ export type components = {
       /**
        * @default 4
        */
      num_steps?: number;
+      /**
+       * Scheduler
+       * @description Scheduler (sampler) for the denoising process. 'euler' is fast and standard. 'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.
+       * @default euler
+       * @enum {string}
+       */
+      scheduler?: "euler" | "heun" | "lcm";
       /**
        * Guidance
        * @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.
@@ -8491,6 +8498,13 @@
       /**
        * @default 4
        */
      num_steps?: number;
+      /**
+       * Scheduler
+       * @description Scheduler (sampler) for the denoising process. 'euler' is fast and standard. 'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.
+       * @default euler
+       * @enum {string}
+       */
+      scheduler?: "euler" | "heun" | "lcm";
       /**
        * Guidance
        * @description The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.