From 1fbc9361e47baa96b8823e9fc9c6814926a7170c Mon Sep 17 00:00:00 2001 From: Pranav Veldurthi Date: Sun, 1 Sep 2024 14:34:35 -0400 Subject: [PATCH 1/2] Fix attention layers map for SD-2-1-Base --- stable_diffusion/stable_diffusion/model_io.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/stable_diffusion/stable_diffusion/model_io.py b/stable_diffusion/stable_diffusion/model_io.py index 2c2227db8..ea85a2180 100644 --- a/stable_diffusion/stable_diffusion/model_io.py +++ b/stable_diffusion/stable_diffusion/model_io.py @@ -140,6 +140,17 @@ def map_vae_weights(key, value): if "to_v" in key: key = key.replace("to_v", "value_proj") + + # Map attention layers in SD-2-1-base:VAE + if "key" in key: + key = key.replace("key", "to_k") + if "proj_attn" in key: + key = key.replace("proj_attn", "out_proj") + if "query" in key: + key = key.replace("query", "query_proj") + if "value" in key: + key = key.replace("value", "value_proj") + # Map the mid block if "mid_block.resnets.0" in key: key = key.replace("mid_block.resnets.0", "mid_blocks.0") From 282304a87d9caf71b0da1d02c585905241160e4c Mon Sep 17 00:00:00 2001 From: Pranav Veldurthi Date: Wed, 4 Sep 2024 09:09:52 -0400 Subject: [PATCH 2/2] fix mapping for sdxl turbo --- stable_diffusion/stable_diffusion/model_io.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/stable_diffusion/stable_diffusion/model_io.py b/stable_diffusion/stable_diffusion/model_io.py index ea85a2180..23aa82e2b 100644 --- a/stable_diffusion/stable_diffusion/model_io.py +++ b/stable_diffusion/stable_diffusion/model_io.py @@ -130,7 +130,17 @@ def map_vae_weights(key, value): if "upsamplers" in key: key = key.replace("upsamplers.0.conv", "upsample") - # Map attention layers + # Map attention layers in SD-2-1-base:VAE + if "key" in key: + key = key.replace("key", "key_proj") + if "proj_attn" in key: + key = key.replace("proj_attn", "out_proj") + if "query" in key: + key = key.replace("query", "query_proj") + if "value" in key: + key = key.replace("value", "value_proj") + + # Map attention layers in SDXL Turbo if "to_k" in key: key = key.replace("to_k", "key_proj") if "to_out.0" in key: @@ -141,16 +151,6 @@ def map_vae_weights(key, value): key = key.replace("to_v", "value_proj") - # Map attention layers in SD-2-1-base:VAE - if "key" in key: - key = key.replace("key", "to_k") - if "proj_attn" in key: - key = key.replace("proj_attn", "out_proj") - if "query" in key: - key = key.replace("query", "query_proj") - if "value" in key: - key = key.replace("value", "value_proj") - # Map the mid block if "mid_block.resnets.0" in key: key = key.replace("mid_block.resnets.0", "mid_blocks.0")