RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 6 #7

balala8 opened this issue Sep 18, 2024 · 9 comments


@balala8

balala8 commented Sep 18, 2024

To reduce memory usage, I use optimum.quanto to quantize the transformer, controlnet, and T5 encoder to fp8, but I encounter an error:

File "/home/yongfang/miniconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 616, in apply_rotary_emb
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 6

My code is here:

import torch
from diffusers.utils import load_image, check_min_version
from controlnet_flux import FluxControlNetModel
from transformer_flux import FluxTransformer2DModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline
from optimum.quanto import freeze, qfloat8, quantize, QuantizedTransformersModel, QuantizedDiffusersModel
from huggingface_hub import login
from transformers import T5EncoderModel, CLIPTextModel
from PIL import Image

check_min_version("0.30.2")

dtype = torch.float16

class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

class QuantizedControlNetModel(QuantizedDiffusersModel):
    base_class = FluxControlNetModel

class QuantizedT5EncoderModelForCausalLM(QuantizedTransformersModel):
    auto_class = T5EncoderModel
    auto_class.from_config = auto_class._from_config

# Build pipeline
controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.float16)
quantize(controlnet, weights=qfloat8)
freeze(controlnet)
transformer = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)
text_encoder_2 = T5EncoderModel.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=None,
    transformer=None,
    text_encoder_2=None,
    torch_dtype=torch.float16
).to("cuda")
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.controlnet = controlnet

pipe = pipe.to(device="cuda")

pipe.controlnet.to(torch.float16)


# Set image path , mask path and prompt
image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png'
mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg'
prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it'

# Load image and mask
size = (768, 768)
image = load_image(image_path).convert("RGB").resize(size)
mask = load_image(mask_path).convert("RGB").resize(size)
generator = torch.Generator(device="cuda").manual_seed(24)

# Inpaint
result = pipe(
    prompt=prompt,
    height=size[1],
    width=size[0],
    control_image=image,
    control_mask=mask,
    num_inference_steps=28,
    generator=generator,
    controlnet_conditioning_scale=0.9,
    guidance_scale=3.5,
    negative_prompt="",
    true_guidance_scale=3.5
).images[0]

result.save('flux_inpaint.png')
print("Successfully inpaint image")

And the complete error message is as follows:

Traceback (most recent call last):
  File "/home/documents/train_ic_light/FLUX-Controlnet-Inpainting/main.py", line 78, in <module>
    prompt=prompt,
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/documents/train_ic_light/FLUX-Controlnet-Inpainting/pipeline_flux_controlnet_inpaint.py", line 956, in __call__
    ) = self.controlnet(
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/documents/train_ic_light/FLUX-Controlnet-Inpainting/controlnet_flux.py", line 332, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/documents/train_ic_light/FLUX-Controlnet-Inpainting/transformer_flux.py", line 214, in forward
    attn_output, context_attn_output = self.attn(
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 490, in forward
    return self.processor(
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1762, in __call__
    query = apply_rotary_emb(query, image_rotary_emb)
  File "/home/miniconda3/envs/diffusers/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 616, in apply_rotary_emb
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 6

I printed the shapes of some variables in the apply_rotary_emb function, as follows:

x shape: torch.Size([2, 24, 2816, 128])
cos shape: torch.Size([1, 1, 1, 2816, 64, 2, 2])
sin shape: torch.Size([1, 1, 1, 2816, 64, 2, 2])
x_rotated shape: torch.Size([2, 24, 2816, 128])
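
For reference, here is a minimal standalone sketch of the failing multiply, using only the shapes printed above (random tensors, purely illustrative):

import torch

# x matches the query shape above; cos matches the rotary table shape above.
# Broadcasting right-aligns the shapes, so the trailing 128 of x meets the
# trailing 2 of cos at dimension 6 of the 7-dim result and the multiply fails.
x = torch.randn(2, 24, 2816, 128)
cos = torch.randn(1, 1, 1, 2816, 64, 2, 2)
out = x.float() * cos
# RuntimeError: The size of tensor a (128) must match the size of tensor b (2) at non-singleton dimension 6

The trailing (..., 2, 2) layout is the old freqs_cis format, which is presumably why rolling back to the old apply_rope helper (see the workarounds later in this thread) avoids the error.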
@matabear-wyx

I tried the same thing; I guess their pipeline just doesn't support a quantized transformer yet.
Use CPU offload instead; it uses around 31 GB of VRAM.
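
For anyone trying that route, a minimal sketch (assuming FluxControlNetInpaintingPipeline inherits the standard DiffusionPipeline offload hooks; model IDs taken from the repro above):

import torch
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline

# Load everything unquantized in bf16 and let diffusers move each submodule to
# the GPU only while it runs; reported to peak around 31 GB of VRAM.
controlnet = FluxControlNetModel.from_pretrained(
    "alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetInpaintingPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # do not also call .to("cuda") when using offload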

@balala8
Author

balala8 commented Sep 19, 2024

I tried the same thing; I guess their pipeline just doesn't support a quantized transformer yet. Use CPU offload instead; it uses around 31 GB of VRAM.

It's strange, isn't it? Why does quantizing the model lead to a rotary-embedding shape mismatch rather than some other error? Also, 31 GB of VRAM exceeds my limit.

@matabear-wyx

I tried the same thing; I guess their pipeline just doesn't support a quantized transformer yet. Use CPU offload instead; it uses around 31 GB of VRAM.

It's strange, isn't it? Why does quantizing the model lead to a rotary-embedding shape mismatch rather than some other error? Also, 31 GB of VRAM exceeds my limit.

Well, I assume their pipeline is data-type sensitive. You may notice that in their sample code they do:

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dtype=torch.bfloat16
)
pipe.transformer.to(torch.bfloat16)

If you comment out the second line, there will be a data-type mismatch error. So I guess quantized models are not supported by this pipeline yet.
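
A quick sanity check before calling the pipeline (a hedged sketch; it just inspects the first parameter of each component):

# Both should print torch.bfloat16 once the casts above have been applied.
print(next(pipe.transformer.parameters()).dtype)
print(next(pipe.controlnet.parameters()).dtype)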

@Wh0ru

Wh0ru commented Sep 24, 2024

I also encountered this problem when I ran the code in main.py directly. Have you ever had this problem without quantization?

@matabear-wyx

I also encountered this problem when I ran the code in main.py directly. Have you ever had this problem without quantization?

Try downgrading to diffusers==0.30.2 and transformers==4.42.0.

@Wh0ru

Wh0ru commented Sep 24, 2024

I also encountered this problem when I ran the code in main.py directly. Have you ever had this problem without quantization?

Try downgrading to diffusers==0.30.2 and transformers==4.42.0.

OK, I'll try. Thanks!

@gxground

The same thing happens after upgrading ComfyUI.

@vbuterin

vbuterin commented Oct 31, 2024

Poking into line ~1780 of ...../site-packages/diffusers/models/attention_processor.py in diffusers 0.31 and changing the code to this (essentially, bringing back a piece of diffusers 0.30.2) seems to solve the issue.

Obviously this is a hacky solution and ideally diffusers itself would fix what's going on, or people can figure out some kind of workaround outside the diffusers library.

        # YiYi to-do: refactor rope related functions/classes
        def apply_rope(xq, xk, freqs_cis):
            xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
            xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
            xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
            xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
            return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query, key = apply_rope(query, key, image_rotary_emb)
            #query = apply_rotary_emb(query, image_rotary_emb)
            #key = apply_rotary_emb(key, image_rotary_emb)
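
An alternative that leaves site-packages untouched is to monkey-patch from user code. This is only a hedged sketch (apply_rotary_emb in diffusers.models.embeddings is real, but the dispatch-on-shape logic is my own assumption): it falls back to the old freqs_cis math whenever the rotary embedding arrives as a single tensor with a trailing (2, 2) layout, and defers to the stock implementation otherwise.

import torch
import diffusers.models.embeddings as embeddings

_original_apply_rotary_emb = embeddings.apply_rotary_emb

def _patched_apply_rotary_emb(x, freqs_cis, *args, **kwargs):
    # Old FLUX layout: [..., seq, head_dim // 2, 2, 2]; newer diffusers expects a (cos, sin) pair.
    if torch.is_tensor(freqs_cis) and freqs_cis.shape[-2:] == (2, 2):
        x_ = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        return x_out.reshape(*x.shape).type_as(x)
    return _original_apply_rotary_emb(x, freqs_cis, *args, **kwargs)

embeddings.apply_rotary_emb = _patched_apply_rotary_emb  # apply before running the pipeline

Because attention_processor.py imports apply_rotary_emb inside the processor's __call__ (see the snippet above), patching the module attribute before the first call is enough, and pipelines that pass the usual (cos, sin) tuple still hit the original code path.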

@brurpo

brurpo commented Nov 1, 2024

@vbuterin that worked wonders, thank you.
Edit: it unfortunately breaks the other pipelines.
A VERY dirty way of fixing this until diffusers comes up with a solution that works with all the pipelines would be the following:

        def apply_rope(xq, xk, freqs_cis):
            xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
            xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
            xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
            xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
            return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            try:
                # old FLUX-style rope (single freqs_cis tensor)
                query, key = apply_rope(query, key, image_rotary_emb)
            except RuntimeError:
                # current diffusers rope ((cos, sin) tuple)
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)
