add PAG support #7944

Merged
merged 50 commits on Jun 25, 2024

Commits (showing changes from 30 of 50 commits)
a6a0429
first draft
yiyixuxu May 13, 2024
3605df9
refactor
yiyixuxu May 14, 2024
f94376c
update
yiyixuxu May 14, 2024
54c3fd6
up
yiyixuxu May 14, 2024
f571430
style
yiyixuxu May 14, 2024
91d0a5b
style
May 14, 2024
01585ab
update
yiyixuxu May 14, 2024
03bdbcd
inpaint + controlnet
yiyixuxu May 14, 2024
b662207
Merge branch 'pag' of github.com:huggingface/diffusers into pag
yiyixuxu May 14, 2024
1fb2c33
style
May 14, 2024
219f4b9
up
May 14, 2024
5641cb4
Update src/diffusers/pipelines/pag_utils.py
yiyixuxu May 14, 2024
8950e80
fix controlnet
KKIEEK May 15, 2024
4cc0b8b
fix compatability issue between PAG and IP-adapter (#8379)
sunovivid Jun 5, 2024
5cbf226
up
yiyixuxu Jun 6, 2024
58804a0
refactor ip-adapter
yiyixuxu Jun 8, 2024
7bc9229
style
Jun 8, 2024
e09e079
Merge branch 'main' into pag
yiyixuxu Jun 8, 2024
1fa54df
style
Jun 8, 2024
ba366f0
u[
Jun 9, 2024
d5a6761
up
Jun 10, 2024
854b70e
fix
Jun 10, 2024
9e4c1b6
add controlnet pag
Jun 10, 2024
623d237
copy
Jun 10, 2024
f30c2bc
add from pipe test for pag + controlnet
Jun 10, 2024
1df4391
up
Jun 10, 2024
191505e
support guess mode
yiyixuxu Jun 17, 2024
58b8330
style
Jun 17, 2024
71cf2f7
add pag + img2img
Jun 17, 2024
6da3bb6
Merge branch 'main' into pag
sayakpaul Jun 17, 2024
1e79c59
remove guess model support from pag controlnet pipeline
yiyixuxu Jun 20, 2024
14b4ddd
noise_pred_uncond -> noise_pred_text
yiyixuxu Jun 20, 2024
91c41e8
Apply suggestions from code review
yiyixuxu Jun 20, 2024
b72ef1c
fix more
yiyixuxu Jun 20, 2024
b7f4ccd
Merge branch 'main' into pag
Jun 24, 2024
d12b4a0
update docstring example
Jun 24, 2024
28e1301
add copied from
Jun 24, 2024
5653b2a
add doc
Jun 25, 2024
17520f2
up
Jun 25, 2024
e11180a
Merge branch 'main' into pag
Jun 25, 2024
074a4f0
fix copies
Jun 25, 2024
18d8b0e
up
Jun 25, 2024
0e337bf
up
Jun 25, 2024
434f63a
up
Jun 25, 2024
41b1ddc
up
Jun 25, 2024
24cadb4
Update docs/source/en/api/pipelines/pag.md
yiyixuxu Jun 25, 2024
c4ceee9
Apply suggestions from code review
yiyixuxu Jun 25, 2024
9db27cf
Update src/diffusers/models/attention_processor.py
yiyixuxu Jun 25, 2024
8ae87e2
add a tip about extending pag support and explain pag scale
Jun 25, 2024
19eb55f
Merge branch 'pag' of github.com:huggingface/diffusers into pag
Jun 25, 2024
8 changes: 8 additions & 0 deletions src/diffusers/__init__.py
@@ -308,11 +308,15 @@
"StableDiffusionXLAdapterPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLControlNetPipeline",
"StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
@@ -696,11 +700,15 @@
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
6 changes: 1 addition & 5 deletions src/diffusers/loaders/unet.py
@@ -922,8 +922,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us

def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
@@ -963,9 +961,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
hidden_size = self.config.block_out_channels[block_id]

if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_processor_class = self.attn_processors[name].__class__
Review comment on lines -966 to +964 (Member): Interesting. Where do we perform the assignment so that self.attn_processors[name].__class__ gives the correct results?

attn_procs[name] = attn_processor_class()

else:
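To illustrate the intent of the change above (an interpretation, not code from this PR): by reusing self.attn_processors[name].__class__, loading IP-Adapter weights no longer resets the self-attention slots to a hard-coded AttnProcessor2_0, so a PAG identity processor installed earlier by the pipeline stays in place. A minimal sketch; the helper name below is made up for illustration:

from diffusers import UNet2DConditionModel

def self_attn_processor_class(unet: UNet2DConditionModel, name: str):
    # unet.attn_processors maps fully qualified processor names to the instances
    # currently set on the UNet, so the class chosen here reflects whatever the
    # pipeline installed earlier (e.g. PAGIdentitySelfAttnProcessor2_0) instead of
    # a fixed default.
    return unet.attn_processors[name].__class__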
228 changes: 228 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2924,6 +2924,232 @@ def __call__(
return hidden_states


class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing perturbed-attention guidance (PAG): in the perturbed half of the batch, self-attention is replaced with an identity mapping of the value projection. Requires PyTorch 2.0's scaled dot-product attention.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

# original path
batch_size, sequence_length, _ = hidden_states_org.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

# perturbed path (identity attention)
batch_size, sequence_length, _ = hidden_states_ptb.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

value = attn.to_v(hidden_states_ptb)

# identity self-attention: the perturbed path skips the attention computation and
# passes the value projection through directly (an earlier draft zeroed it out instead)
hidden_states_ptb = value
hidden_states_ptb = hidden_states_ptb.to(query.dtype)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
Review comment (Member), nit: I think these lines of code could be clubbed together and accompanied with a comment.


if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class PAGCFGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing perturbed-attention guidance (PAG) together with classifier-free guidance: the batch carries unconditional, conditional, and perturbed samples, and self-attention in the perturbed part is replaced with an identity mapping of the value projection. Requires PyTorch 2.0's scaled dot-product attention.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

# original path
batch_size, sequence_length, _ = hidden_states_org.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

# perturbed path (identity attention)
batch_size, sequence_length, _ = hidden_states_ptb.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

value = attn.to_v(hidden_states_ptb)
hidden_states_ptb = value
hidden_states_ptb = hidden_states_ptb.to(query.dtype)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2964,6 +3190,8 @@ def __call__(
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
# deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
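The processors above only produce the chunked attention outputs; the guidance itself is applied to the UNet's noise predictions inside the pipeline. A minimal sketch of that combination step, assuming the batch ordering used above ([uncond, text, perturbed] with classifier-free guidance, [text, perturbed] without) and a pag_scale parameter; the exact helper in this PR's pag_utils.py may differ:

import torch

def combine_noise_predictions(
    noise_pred: torch.Tensor,
    guidance_scale: float,
    pag_scale: float,
    do_classifier_free_guidance: bool,
) -> torch.Tensor:
    # Sketch only: mirrors the chunk ordering produced by the PAG attention processors.
    if do_classifier_free_guidance:
        # PAGCFGIdentitySelfAttnProcessor2_0 sees a tripled batch: [uncond, text, perturbed]
        noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
        return (
            noise_pred_uncond
            + guidance_scale * (noise_pred_text - noise_pred_uncond)
            + pag_scale * (noise_pred_text - noise_pred_perturb)
        )
    # PAGIdentitySelfAttnProcessor2_0 sees a doubled batch: [text, perturbed]
    noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
    return noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)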
15 changes: 15 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -25,6 +25,7 @@
"latent_diffusion": [],
"ledits_pp": [],
"marigold": [],
"pag": [],
"stable_diffusion": [],
"stable_diffusion_xl": [],
}
@@ -136,6 +137,14 @@
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["pag"].extend(
[
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
]
)
_import_structure["controlnet_xs"].extend(
[
"StableDiffusionControlNetXSPipeline",
@@ -463,6 +472,12 @@
MarigoldNormalsPipeline,
)
from .musicldm import MusicLDMPipeline
from .pag import (
StableDiffusionXLControlNetPAGPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
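For orientation, a minimal usage sketch of the text-to-image PAG pipeline registered above. The pag_scale call argument is an assumption based on the commit "add a tip about extending pag support and explain pag scale"; see docs/source/en/api/pipelines/pag.md added in this PR for the final API.

import torch
from diffusers import StableDiffusionXLPAGPipeline

# Hedged sketch: the pag_scale argument name is assumed, not confirmed by this diff hunk.
pipe = StableDiffusionXLPAGPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    prompt="an astronaut riding a horse on mars",
    guidance_scale=7.0,  # classifier-free guidance strength
    pag_scale=3.0,       # perturbed-attention guidance strength
).images[0]
image.save("pag_sample.png")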
40 changes: 17 additions & 23 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -352,6 +352,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
@@ -361,44 +364,35 @@ def prepare_ip_adapter_image_embeds(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
repeat_dims = [1]
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)

return image_embeds
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)

return ip_adapter_image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
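To make the refactored layout concrete, a small illustrative sketch (the shapes are invented for the example): with classifier-free guidance enabled, each tensor returned per IP-Adapter stacks the repeated negative embeds ahead of the repeated positive embeds along the batch dimension, exactly as the loop above does.

import torch

num_images_per_prompt = 2
single_image_embeds = torch.randn(1, 4, 768)           # one adapter's positive image embeds
single_negative_image_embeds = torch.zeros(1, 4, 768)  # its negative (unconditional) embeds

positive = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)           # (2, 4, 768)
negative = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)  # (2, 4, 768)
with_cfg = torch.cat([negative, positive], dim=0)                                    # (4, 4, 768)
print(with_cfg.shape)  # torch.Size([4, 4, 768])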