Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion invokeai/app/invocations/flux_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
pack,
unpack,
)
from invokeai.backend.flux.schedulers import FLUX_SCHEDULER_LABELS, FLUX_SCHEDULER_MAP, FLUX_SCHEDULER_NAME_VALUES
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.taxonomy import BaseModelType, FluxVariantType, ModelFormat, ModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
Expand All @@ -63,7 +64,7 @@
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.1.0",
version="4.2.0",
)
class FluxDenoiseInvocation(BaseInvocation):
"""Run denoising process with a FLUX transformer model."""
Expand Down Expand Up @@ -132,6 +133,12 @@ class FluxDenoiseInvocation(BaseInvocation):
num_steps: int = InputField(
default=4, description="Number of diffusion steps. Recommended values are schnell: 4, dev: 50."
)
scheduler: FLUX_SCHEDULER_NAME_VALUES = InputField(
default="euler",
description="Scheduler (sampler) for the denoising process. 'euler' is fast and standard. "
"'heun' is 2nd-order (better quality, 2x slower). 'lcm' is optimized for few steps.",
ui_choice_labels=FLUX_SCHEDULER_LABELS,
)
guidance: float = InputField(
default=4.0,
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
Expand Down Expand Up @@ -242,6 +249,12 @@ def _run_diffusion(
shift=not is_schnell,
)

# Create scheduler if not using default euler
scheduler = None
if self.scheduler in FLUX_SCHEDULER_MAP:
scheduler_class = FLUX_SCHEDULER_MAP[self.scheduler]
scheduler = scheduler_class(num_train_timesteps=1000)

# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)

Expand Down Expand Up @@ -426,6 +439,7 @@ def _run_diffusion(
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
scheduler=scheduler,
)

x = unpack(x.float(), self.height, self.width)
Expand Down
311 changes: 229 additions & 82 deletions invokeai/app/invocations/z_image_denoise.py

Large diffs are not rendered by default.

199 changes: 188 additions & 11 deletions invokeai/backend/flux/denoise.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import inspect
import math
from typing import Callable

import torch
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from tqdm import tqdm

from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
Expand Down Expand Up @@ -35,24 +37,199 @@ def denoise(
# extra img tokens (sequence-wise) - for Kontext conditioning
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
# Optional scheduler for alternative sampling methods
scheduler: SchedulerMixin | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=img,
),
)
# Determine if we're using a diffusers scheduler or the built-in Euler method
use_scheduler = scheduler is not None

if use_scheduler:
# Initialize scheduler with timesteps
# The timesteps list contains values in [0, 1] range (sigmas)
# Some schedulers (like Euler) support custom sigmas, others (like Heun) don't
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
if "sigmas" in set_timesteps_sig.parameters:
# Scheduler supports custom sigmas - use InvokeAI's time-shifted schedule
scheduler.set_timesteps(sigmas=timesteps, device=img.device)
else:
# Scheduler doesn't support custom sigmas - use num_inference_steps
# The schedule will be computed by the scheduler itself
num_inference_steps = len(timesteps) - 1
scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=img.device)

# For schedulers like Heun, the number of actual steps may differ
# (Heun doubles timesteps internally)
num_scheduler_steps = len(scheduler.timesteps)
# For user-facing step count, use the original number of denoising steps
total_steps = len(timesteps) - 1
else:
total_steps = len(timesteps) - 1
num_scheduler_steps = total_steps

# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]

# Track the actual step for user-facing progress (accounts for Heun's double steps)
user_step = 0

if use_scheduler:
# Use diffusers scheduler for stepping
for step_index in tqdm(range(num_scheduler_steps)):
timestep = scheduler.timesteps[step_index]
# Convert scheduler timestep (0-1000) to normalized (0-1) for the model
t_curr = timestep.item() / scheduler.config.num_train_timesteps
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

# For Heun scheduler, track if we're in first or second order step
is_heun = hasattr(scheduler, "state_in_first_order")
in_first_order = scheduler.state_in_first_order if is_heun else True

# Run ControlNet models
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=user_step,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
)
)

merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)

# Prepare input for model
img_input = img
img_input_ids = img_ids

if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)

if img_cond_seq is not None:
assert img_cond_seq_ids is not None
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)

pred = model(
img=img_input,
img_ids=img_input_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=user_step,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
regional_prompting_extension=pos_regional_prompting_extension,
)

if img_cond_seq is not None:
pred = pred[:, :original_seq_len]

# Get CFG scale for current user step
step_cfg_scale = cfg_scale[min(user_step, len(cfg_scale) - 1)]

if not math.isclose(step_cfg_scale, 1.0):
if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")

neg_img_input = img
neg_img_input_ids = img_ids

if img_cond is not None:
neg_img_input = torch.cat((neg_img_input, img_cond), dim=-1)

if img_cond_seq is not None:
neg_img_input = torch.cat((neg_img_input, img_cond_seq), dim=1)
neg_img_input_ids = torch.cat((neg_img_input_ids, img_cond_seq_ids), dim=1)

neg_pred = model(
img=neg_img_input,
img_ids=neg_img_input_ids,
txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=user_step,
total_num_timesteps=total_steps,
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
regional_prompting_extension=neg_regional_prompting_extension,
)

if img_cond_seq is not None:
neg_pred = neg_pred[:, :original_seq_len]
pred = neg_pred + step_cfg_scale * (pred - neg_pred)

# Use scheduler.step() for the update
step_output = scheduler.step(model_output=pred, timestep=timestep, sample=img)
img = step_output.prev_sample

# Get t_prev for inpainting (next sigma value)
if step_index + 1 < len(scheduler.sigmas):
t_prev = scheduler.sigmas[step_index + 1].item()
else:
t_prev = 0.0

if inpaint_extension is not None:
img = inpaint_extension.merge_intermediate_latents_with_init_latents(img, t_prev)

# For Heun, only increment user step after second-order step completes
if is_heun:
if not in_first_order:
# Second order step completed
user_step += 1
# Only call step_callback if we haven't exceeded total_steps
if user_step <= total_steps:
preview_img = img - t_curr * pred
if inpaint_extension is not None:
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(
preview_img, 0.0
)
step_callback(
PipelineIntermediateState(
step=user_step,
order=2,
total_steps=total_steps,
timestep=int(t_curr * 1000),
latents=preview_img,
),
)
else:
# For Euler, LCM and other first-order schedulers
user_step += 1
# Only call step_callback if we haven't exceeded total_steps
# (LCM scheduler may have more internal steps than user-facing steps)
if user_step <= total_steps:
preview_img = img - t_curr * pred
if inpaint_extension is not None:
preview_img = inpaint_extension.merge_intermediate_latents_with_init_latents(preview_img, 0.0)
step_callback(
PipelineIntermediateState(
step=user_step,
order=1,
total_steps=total_steps,
timestep=int(t_curr * 1000),
latents=preview_img,
),
)

return img

# Original Euler implementation (when scheduler is None)
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)

Expand Down
62 changes: 62 additions & 0 deletions invokeai/backend/flux/schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Flow Matching scheduler definitions and mapping.

This module provides the scheduler types and mapping for Flow Matching models
(Flux and Z-Image), supporting multiple schedulers from the diffusers library.
"""

from typing import Literal, Type

from diffusers import (
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin

# Note: FlowMatchLCMScheduler may not be available in all diffusers versions
try:
from diffusers import FlowMatchLCMScheduler

_HAS_LCM = True
except ImportError:
_HAS_LCM = False

# Scheduler name literal type for type checking
FLUX_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]

# Human-readable labels for the UI
FLUX_SCHEDULER_LABELS: dict[str, str] = {
"euler": "Euler",
"heun": "Heun (2nd order)",
"lcm": "LCM",
}

# Mapping from scheduler names to scheduler classes
FLUX_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
"euler": FlowMatchEulerDiscreteScheduler,
"heun": FlowMatchHeunDiscreteScheduler,
}

if _HAS_LCM:
FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler


# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
# can be used for experimentation.
ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]

# Human-readable labels for the UI
ZIMAGE_SCHEDULER_LABELS: dict[str, str] = {
"euler": "Euler",
"heun": "Heun (2nd order)",
"lcm": "LCM",
}

# Mapping from scheduler names to scheduler classes (same as Flux)
ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
"euler": FlowMatchEulerDiscreteScheduler,
"heun": FlowMatchHeunDiscreteScheduler,
}

if _HAS_LCM:
ZIMAGE_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ const slice = createSlice({
setScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
state.scheduler = action.payload;
},
setFluxScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
state.fluxScheduler = action.payload;
},
setZImageScheduler: (state, action: PayloadAction<'euler' | 'heun' | 'lcm'>) => {
state.zImageScheduler = action.payload;
},
setUpscaleScheduler: (state, action: PayloadAction<ParameterScheduler>) => {
state.upscaleScheduler = action.payload;
},
Expand Down Expand Up @@ -449,6 +455,8 @@ export const {
setCfgRescaleMultiplier,
setGuidance,
setScheduler,
setFluxScheduler,
setZImageScheduler,
setUpscaleScheduler,
setUpscaleCfgScale,
setSeed,
Expand Down Expand Up @@ -588,6 +596,8 @@ export const selectModelSupportsOptimizedDenoising = createSelector(
(model) => !!model && SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS.includes(model.base)
);
export const selectScheduler = createParamsSelector((params) => params.scheduler);
export const selectFluxScheduler = createParamsSelector((params) => params.fluxScheduler);
export const selectZImageScheduler = createParamsSelector((params) => params.zImageScheduler);
export const selectSeamlessXAxis = createParamsSelector((params) => params.seamlessXAxis);
export const selectSeamlessYAxis = createParamsSelector((params) => params.seamlessYAxis);
export const selectSeed = createParamsSelector((params) => params.seed);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
zParameterCLIPGEmbedModel,
zParameterCLIPLEmbedModel,
zParameterControlLoRAModel,
zParameterFluxScheduler,
zParameterGuidance,
zParameterImageDimension,
zParameterMaskBlurMethod,
Expand All @@ -23,6 +24,7 @@ import {
zParameterStrength,
zParameterT5EncoderModel,
zParameterVAEModel,
zParameterZImageScheduler,
} from 'features/parameters/types/parameterSchemas';
import type { JsonObject } from 'type-fest';
import { z } from 'zod';
Expand Down Expand Up @@ -596,6 +598,8 @@ export const zParamsState = z.object({
optimizedDenoisingEnabled: z.boolean(),
iterations: z.number(),
scheduler: zParameterScheduler,
fluxScheduler: zParameterFluxScheduler,
zImageScheduler: zParameterZImageScheduler,
upscaleScheduler: zParameterScheduler,
upscaleCfgScale: zParameterCFGScale,
seed: zParameterSeed,
Expand Down Expand Up @@ -650,6 +654,8 @@ export const getInitialParamsState = (): ParamsState => ({
optimizedDenoisingEnabled: true,
iterations: 1,
scheduler: 'dpmpp_3m_k',
fluxScheduler: 'euler',
zImageScheduler: 'euler',
upscaleScheduler: 'kdpm_2',
upscaleCfgScale: 2,
seed: 0,
Expand Down
Loading