
add PAG support #7944

Open · wants to merge 26 commits into main
Conversation

@yiyixuxu (Collaborator) commented May 14, 2024

Notes on implementation

separate pipeline class

We created a separate pipeline group for PAG so that we can support it (and more such features in the future) while keeping our SD and SDXL pipelines lightweight for the research community.

PAGMixin

PAGMixin abstracts away all PAG-related logic so that the PAG pipelines keep a structure consistent with the rest of the pipelines. This makes them easier to read, and also easier to integrate and maintain.

AutoPipeline API

  • You can pass enable_pag=True to automatically create a pipeline with PAG enabled, based on the task you specified and the checkpoint you provided. Under the hood, it creates the corresponding PAG pipeline. A few examples:
# SDXL + PAG:
AutoPipelineForText2Image.from_pretrained(repo_id, enable_pag=True, ...)

# SDXL + ControlNet + PAG:
AutoPipelineForText2Image.from_pretrained(repo_id, controlnet=controlnet, enable_pag=True, ...)

# SDXL Inpainting + PAG:
AutoPipelineForInpainting.from_pretrained(repo_id, enable_pag=True, ...)
  • The from_pipe API also works, and works intuitively (I hope). A few examples:
# StableDiffusionXLControlNetPipeline to StableDiffusionXLPAGPipeline
AutoPipelineForText2Image.from_pipe(pipe_controlnet, controlnet=None, enable_pag=True)

# StableDiffusionXLPAGPipeline to StableDiffusionXLPipeline:
AutoPipelineForText2Image.from_pipe(pipe_pag, enable_pag=False)

pag_applied_layers

  • you can set pag_applied_layers when you create the pipeline, e.g.
AutoPipelineForText2Image.from_pretrained(repo_id, enable_pag=True, pag_applied_layers=["down.block_1", "up.block_0.attentions_0"])
  • you can use set_pag_applied_layers to update these layers after the pipeline has been created
pipe_pag.set_pag_applied_layers("mid")
  • The accepted value for set_pag_applied_layers is either a single string or a list of strings. You can:
    • apply PAG to all the down, mid, or up blocks with inputs such as "down", "mid", "up"
    • apply PAG to specific down, mid, or up blocks with inputs such as "down.block_0", "up.block_1"
    • apply PAG to specific attention modules within specific blocks with inputs such as "down.block_0.attentions_0"
  • You will get an error if an input is not formatted correctly or the layer you specified does not exist in the model. A short sketch exercising all three forms follows this list.
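For illustration, a minimal sketch of the three granularities (assuming pipe_pag was created with enable_pag=True as in the examples above):

# all mid blocks
pipe_pag.set_pag_applied_layers("mid")
# specific down/up blocks
pipe_pag.set_pag_applied_layers(["down.block_0", "up.block_1"])
# a single attention module within a block
pipe_pag.set_pag_applied_layers("down.block_0.attentions_0")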

other notes:

  • I refactored prepare_ip_adapter_image_embeds a little so that we duplicate inputs for CFG only once, at the end; that's why so many files changed. You only need to look at the pag folder and the auto_pipeline.py file under the pipelines folder when reviewing this PR. A rough sketch of the refactor idea follows.
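For context, a minimal sketch of the refactor idea (not the actual diffusers implementation; encode here is a hypothetical per-image encoder returning positive and negative embeds):

import torch

def prepare_image_embeds_sketch(encode, images, do_classifier_free_guidance):
    # collect positive and negative embeds per IP-Adapter image
    embeds, negative_embeds = [], []
    for image in images:
        pos, neg = encode(image)  # hypothetical encoder
        embeds.append(pos)
        negative_embeds.append(neg)
    if do_classifier_free_guidance:
        # duplicate inputs for CFG only once, at the end
        embeds = [torch.cat([neg, pos]) for neg, pos in zip(negative_embeds, embeds)]
    return embeds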

Usage Examples

SDXL + PAG

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    pag_applied_layers=["down.block_2", "up.block_1.attentions_0"],
    torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()


pag_scales = [0.0, 3.0]
guidance_scales = [0.0, 7.0]

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
grid = []
for pag_scale in pag_scales:
    for guidance_scale in guidance_scales:
        generator = torch.Generator(device="cpu").manual_seed(0)
        images = pipeline(
            prompt="a polar bear sitting in a chair drinking a milkshake",
            negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
            num_inference_steps=25,
            guidance_scale=guidance_scale,
            generator=generator,
            pag_scale=pag_scale,
        ).images
        grid.append(images[0])

# save the grid
from diffusers.utils import make_image_grid
make_image_grid(grid, rows=len(pag_scales), cols=len(guidance_scales)).save("yiyi_test_2_out.png")

[result grid: yiyi_test_2_out.png]

SDXL + PAG + IP-Adapter

Works with IP-Adapter now, thanks to @sunovivid.

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection
import torch

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16
)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    enable_pag=True,
    torch_dtype=torch.float16
).to("cuda")

pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.bin")

pag_scales = [0.0, 3.0]
ip_adapter_scales = [0.0, 0.6]

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
grid = []
for pag_scale in pag_scales:
    for ip_adapter_scale in ip_adapter_scales:
        pipeline.set_ip_adapter_scale(ip_adapter_scale)
        generator = torch.Generator(device="cpu").manual_seed(0)
        images = pipeline(
            prompt="a polar bear sitting in a chair drinking a milkshake",
            ip_adapter_image=image,
            negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
            num_inference_steps=25,
            guidance_scale=3.0,
            generator=generator,
            pag_scale=pag_scale,
        ).images
        grid.append(images[0])

# save the grid
from diffusers.utils import make_image_grid
make_image_grid(grid, rows=len(pag_scales), cols=len(ip_adapter_scales)).save("yiyi_test_4_out.png")

[result grid: pag_ip]

SDXL Inpainting + PAG

# pag integration test: inpaint
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForInpainting.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    # pag_applied_layers=["down.block_2", "up.block_1.attentions_0"],
    pag_applied_layers="mid",
    torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()


img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

prompt = "A majestic tiger sitting on a bench"


pag_scales = [0.0, 3.0]
guidance_scales = [0.0, 7.5]

grid = []
for pag_scale in pag_scales:
    for guidance_scale in guidance_scales:
        generator = torch.Generator(device="cpu").manual_seed(1)
        images = pipeline(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            strength=0.8,
            num_inference_steps=50,
            guidance_scale=guidance_scale,
            generator=generator,
            pag_scale=pag_scale,
        ).images
        grid.append(images[0])

# save the grid
from diffusers.utils import make_image_grid
make_image_grid(grid, rows=len(pag_scales), cols=len(guidance_scales)).save("yiyi_test_4_out.png")

[result grid: yiyi_test_4_out.png]

SDXL + ControlNet + PAG

# pag integration test: controlnet + pag
from diffusers import AutoPipelineForText2Image, ControlNetModel, AutoencoderKL
from diffusers.utils import load_image
import numpy as np
import torch

import cv2
from PIL import Image

# initialize the models and pipeline
controlnet_conditioning_scale = 0.5  # recommended for good generalization
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    enable_pag=True,
    pag_applied_layers = "mid",
    torch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()
print(f" pipeline: {pipeline.__class__}")


# download an image
image = load_image(
    "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
# get canny image
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)

prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = "low quality, bad quality, sketches"


pag_scales = [0.0, 3.0]
guidance_scales = [0.0, 7.5]


grid = []
for pag_scale in pag_scales:
    for guidance_scale in guidance_scales:
        generator = torch.Generator(device="cpu").manual_seed(1)
        images = pipeline(
            prompt=prompt,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            image=canny_image,
            num_inference_steps=50,
            guidance_scale=guidance_scale,
            generator=generator,
            pag_scale=pag_scale,
        ).images
        grid.append(images[0])

# save the grid
from diffusers.utils import make_image_grid
make_image_grid(grid, rows=len(pag_scales), cols=len(guidance_scales)).save("yiyi_test_5_out.png")

[result grid: yiyi_test_5_out.png]


@yiyixuxu (Collaborator Author)

@asomoza can you test it out?
I tried to make it work with IP-Adapter, but I don't think it works. Do you know if PAG works with IP-Adapter?
What other pipelines should I add this to for testing?

@yiyixuxu (Collaborator Author)

cc @HyoungwonCho for awareness
Also, a question: does PAG work with IP-Adapter?

@yiyixuxu added the contributions-welcome and help wanted labels on May 15, 2024
@asomoza (Member) commented May 15, 2024

I've been doing some tests and I like it a lot.

[comparison images: no PAG vs. PAG + CFG]

I think it makes the robot more coherent and it fixes some of the wrong details, but it makes it less "humanoid" and loses a bit of the cinematic look.

I'm still deciding whether I'd prefer a layer/block naming scheme like the one used for LoRAs and IP-Adapter, or whether pag_applied_layers and pag_applied_layers_index is better. I'll give some examples to evaluate this.

So let's say I want to test it with what I normally use for the pose in LoRAs, which is all the layers in down block 2; with the current system I need to do this:

pag_applied_layers_index = ["d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23"]`

the equivalent could be this:

pag_applied_layers = {"down": ["block_2"]}

or for example the last attention block which is what we can associate to the composition with IP Adapters:

pag_applied_layers_index = ["d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23"]`

for this, an equivalent could be:

pag_applied_layers = {"down": "block_2": "attentions_1"}
[comparison images: down_block_2 vs. down_block_2_attentions_1]

I don't know if going as granular as individual layers brings any benefit; even someone like me who likes full control won't go as far as trying to control an image with 70 different layers on top of everything else.

As an example, as an advanced user, I want to use PAG to make the image better but without the robot losing its humanoid form and the cinematic look.

Doing some quick tests, I found that for this particular image, this works really well:

pipeline.enable_pag(
    pag_scale=3.0,
    pag_applied_layers=None,
    pag_applied_layers_index=[
        "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "u0", "u1", "u2", "u3", "u4", "u5", "u6", "u7", "u8", "u9",
    ],
)

which in the lora format would be like this:

pag_applied_layers = {"down": "block_2", "up": "block_1": "attentions_0"}

[result image: 20240515025029_925590493]

Hope this example is somewhat clear. We can also see that the choice of layers matters a lot; the image is much better with this.

I'll do tests with the other use cases later, especially with the upscaler.

@HyoungwonCho (Contributor) commented May 15, 2024

@yiyixuxu @asomoza Hello, I was impressed by the various experiments you conducted using PAG!
We are also discussing the use of PAG in various tasks, as well as layer/scale selection.

Since the guidance framework of PAG itself is simple, it seems quite possible to use it in conjunction with other modules like the IP-Adapter you mentioned. However, we have not yet implemented and experimented with it directly, so we have not confirmed whether there is a significant performance improvement when used together. If possible, we will conduct additional experiments in the future.

Thank you for your interest in our research.

@KKIEEK commented May 15, 2024

Thank you for the great work!
However, I encountered the following issue when using StableDiffusionXLControlNetPipeline with CFG and PAG:

  File ".../.env/lib/python3.11/site-packages/diffusers/models/controlnet.py", line 798, in forward
    sample = sample + controlnet_cond
             ~~~~~~~^~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0

I solved it by adding a new parameter, do_perturbed_attention_guidance, and appending the following lines to the prepare_image method.

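        # with CFG and PAG enabled together, the latent batch is tripled
        # (negative, positive, perturbed), so the conditioning image must match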
        if do_classifier_free_guidance and do_perturbed_attention_guidance and not guess_mode:
            image = torch.cat([image] * 3)
        elif do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)
        elif do_perturbed_attention_guidance and not guess_mode:
            image = torch.cat([image] * 2)

@yiyixuxu (Collaborator Author)

@KKIEEK
Thanks! I added your change :)

@jorgemcgomes (Contributor) commented May 16, 2024

Just leaving a brief report of my findings with PAG and Diffusers (I already had it integrated in my pipelines before this PR):

  • It generally works very very well when properly tuned. Almost looks like a significant model upgrade.
  • I'm using it with models derived from SD2.1.
  • Implemented it successfully in text-to-image, image-to-image, controlnet, unclip, and inpainting pipelines.
  • I get the best results with values around guidance_scale=7 and pag_scale=3.
  • The layers to which it is applied make a huge difference in the output; it's the difference between garbage and excellent. Adding or removing a single layer can make or break it.
  • For example, for SD2.1, I found that with just [m0] the effect was too subtle, [d4, d5, m0] was overcooked, and [d5, m0] seemed to work best; adding any up layers (e.g. [d5, m0, u0]) typically screws up the results.
  • The applied layers will obviously change across model architectures, and I imagine the "optimal" layers might even change with fine-tunes. I couldn't replicate the optimal parameters described in the paper (for SD1.5) with SD2.1 (which has the same UNet architecture).

@yiyixuxu (Collaborator Author)

@jorgemcgomes thanks!

@sunovivid

Hello. I'm an author of PAG. Thank you for your insightful opinions and cool implementation. Is there anything currently in progress? We are excited to see that PAG is gaining popularity within the community and being utilized in various workflows. Especially in ComfyUI, PAG nodes are used in diverse workflows.

(Some workflows using PAG in ComfyUI:
https://www.reddit.com/r/StableDiffusion/comments/1c68qao/perturbedattention_guidance_really_helps_with/
https://civitai.com/models/141592/pixelwave
https://civitai.com/models/413564/cjs-super-simple-high-detail-cosxl-and-pag-workflow
https://www.reddit.com/r/StableDiffusion/comments/1c4cb3l/improve_stable_diffusion_prompt_following_image/
https://www.reddit.com/r/StableDiffusion/comments/1ck69az/make_it_good_options_in_stable_diffusion/
https://stable-diffusion-art.com/perturbed-attention-guidance/)

However, in Diffusers, it seems somewhat challenging to try creative combinations as the pipelines are separated.
( a collection of PAG pipelines with Diffusers: https://x.com/multimodalart/status/1788844183760847106 )

Therefore, the Mixin approach taken in this PR appears to be a very effective solution. However, it seems a bit awkward to call enable_pag every time just to adjust the PAG scale. It would be more natural to set pag_scale when calling the pipeline after enable_pag (similar to passing ip_adapter_image=image at call time after load_ip_adapter), as sketched below. So I'm exploring a better design for this.
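For illustration, a hypothetical sketch of the usage pattern being proposed (not the current API):

pipeline.enable_pag(pag_applied_layers="mid")         # configure PAG once
image = pipeline(prompt, pag_scale=3.0).images[0]     # adjust the scale per call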

Additionally, since many users want compatibility with IP-Adapter, and I now have time, I would like to work on making PAG compatible with IP-Adapter. I'm curious whether there's any related progress on component design or IP-Adapter compatibility.

Thank you!

@yiyixuxu (Collaborator Author)

@sunovivid thanks for the message!
this is not the finalized design, just something we can use to test out PAG compatibility; we will iterate on the final design

for IP-Adapter, it will be super cool if we can make it work! I'm not aware of any related progress, so I would really appreciate it if you are able to find time to work on this! Maybe we can just pick one of the pipelines from this PR (with the mixin) and make it work with the ip_adapter_image input?

@sunovivid commented Jun 2, 2024

@yiyixuxu Hi! I made a working version of PAG + IP-adapter. Can you check the PR?

@yiyixuxu (Collaborator Author) commented Jun 3, 2024

@sunovivid we will merge it in and work on a new design for PAG once you upload the new changes for IP-Adapter :)

for pag_applied_layers:

  1. I think we should use the LoRA format; let me know what you think @sunovivid. See @asomoza's comments and experiments in add PAG support #7944 (comment); you can also find more about the scale dict we support in IP-Adapter and LoRA here and here.
  2. Is pag_applied_layers something we would want to change a lot across generations? I.e., can we make it a pipeline config/attribute instead of a call argument? I think we will have to make pag_scale a call argument.

@sunovivid commented Jun 4, 2024

Hi @yiyixuxu,

Thank you for the feedback!

I might have misunderstood something. Should I upload the new changes for the ip-adapter in this PR? How can I upload the changes? Should I attach files or use another approach?

for pag_applied_layers:

  1. Completely agree! For user convenience, the overall code should consistently follow the conventions used in the Diffusers codebase.
  2. I believe once the best choice for pag_applied_layers is determined per model through experiments (like the great example you provided in @asomoza's comment), it likely won't need frequent changes. Users will likely follow the recommended approach for each model. I also agree that pag_scale should be a call argument.

@@ -508,6 +508,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
@yiyixuxu (Collaborator Author) commented on this diff:

I refactored this method a little bit. This test runs:

import torch

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image


pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
)

pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=[
        "ip-adapter_sdxl_vit-h.safetensors",
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
    image_encoder_folder="models/image_encoder",
)
pipeline.set_ip_adapter_scale([0.1, 0.7, 0.3])
pipeline.to("cuda")

face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/women_input.png")
style_folder = "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy"
style_images = [load_image(f"{style_folder}/img{i}.png") for i in range(10)]

prompt = "wonderwoman"
num_images_per_prompt = 1
guidance_scale = 7.5
do_classifier_free_guidance = guidance_scale > 1

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipeline(
    prompt=prompt,
    ip_adapter_image=[face_image, style_images, face_image],
    negative_prompt="",
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
).images[0]

image.save("yiyi_test_12_out_imgs.png")

with torch.no_grad():
    image_embeds = pipeline.prepare_ip_adapter_image_embeds(
        [face_image, style_images, face_image],
        None,
        "cuda",
        num_images_per_prompt,
        do_classifier_free_guidance,
    )

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipeline(
    prompt=prompt,
    ip_adapter_image_embeds=image_embeds,
    negative_prompt="",
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images_per_prompt,
    generator=generator,
).images[0]
image.save("yiyi_test_12_out_img_embeds.png")

[result image: yiyi_test_12_out_img_embeds.png]

@yiyixuxu (Collaborator Author)

@HyoungwonCho @sunovivid
this PR is ready for a final review now! I would appreciate it if you could also take a look!
I updated the PR description #7944 (comment)

@yiyixuxu (Collaborator Author) commented Jun 10, 2024

cc @apolinario and @vladmandic

we plan to support more popular features like PAG in diffusers, so design-wise, this PR sets the example for future PRs. Would appreciate your inputs too :)

@vladmandic (Contributor)

thanks @yiyixuxu

from a quick glance, new "magic" is mostly in src/diffusers/pipelines/auto_pipeline.py triggered on kwargs.

PAG itself is still a separate pipeline and can be used as a separate pipeline; it's just that autopipeline will do the automatic switching if enable_pag is in kwargs:

orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
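for reference, a quick illustration of what that replace produces (plain Python string ops):

"StableDiffusionXLControlNetPipeline".replace("Pipeline", "PAGPipeline")
# -> 'StableDiffusionXLControlNetPAGPipeline'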

i'm ok with that; one potential issue is propagation of future fixes - e.g. if a fix lands somewhere in StableDiffusionPipeline and autopipeline does a behind-the-scenes switch to StableDiffusionPAGPipeline, then we really need to ensure there are no regressions there, since the user is not even explicitly aware of that switch

just not sure about the mappings using string replace - ok for PAG, but would this pattern apply universally?

text_2_image_cls.__name__.replace("PAG", "").replace("Pipeline", "PAGPipeline"),
