Commit: Whitespace
fofr committed Sep 27, 2024
1 parent 01e81a3 commit 126730e
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions lora_loading_patch.py
@@ -1,9 +1,18 @@
-from diffusers.utils import convert_unet_state_dict_to_peft, get_peft_kwargs, is_peft_version, get_adapter_name, logging
+from diffusers.utils import (
+    convert_unet_state_dict_to_peft,
+    get_peft_kwargs,
+    is_peft_version,
+    get_adapter_name,
+    logging,
+)
 
 logger = logging.get_logger(__name__)
 
 
 # patching inject_adapter_in_model and load_peft_state_dict with low_cpu_mem_usage=True until it's merged into diffusers
-def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+def load_lora_into_transformer(
+    cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None
+):
     """
     This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -28,7 +37,9 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
 
     transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
     state_dict = {
-        k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        k.replace(f"{cls.transformer_name}.", ""): v
+        for k, v in state_dict.items()
+        if k in transformer_keys
     }
 
     if len(state_dict.keys()) > 0:
@@ -49,10 +60,20 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
 
         if network_alphas is not None and len(network_alphas) >= 1:
             prefix = cls.transformer_name
-            alpha_keys = [k for k in network_alphas if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            alpha_keys = [
+                k
+                for k in network_alphas
+                if k.startswith(prefix) and k.split(".")[0] == prefix
+            ]
+            network_alphas = {
+                k.replace(f"{prefix}.", ""): v
+                for k, v in network_alphas.items()
+                if k in alpha_keys
+            }
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+        lora_config_kwargs = get_peft_kwargs(
+            rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
+        )
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                 raise ValueError(
@@ -67,10 +88,16 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
        # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload = (
+            cls._optionally_disable_offloading(_pipeline)
+        )
 
-        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, low_cpu_mem_usage=True)
+        inject_adapter_in_model(
+            lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True
+        )
+        incompatible_keys = set_peft_model_state_dict(
+            transformer, state_dict, adapter_name, low_cpu_mem_usage=True
+        )
 
         if incompatible_keys is not None:
             # check only for unexpected keys
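For context, here is a minimal sketch of how a patched load_lora_into_transformer like this might be wired into a Flux pipeline. The pipeline class, model id, and LoRA repo id below are assumptions for illustration, not part of this commit; the idea is simply that, because the function still takes cls as its first argument, it can be bound onto the pipeline class as a classmethod in place of the stock diffusers implementation until low_cpu_mem_usage support is merged upstream.

# Hypothetical usage sketch -- pipeline class, model id, and LoRA repo id are
# assumptions for illustration, not part of this commit.
import torch
from diffusers import FluxPipeline

from lora_loading_patch import load_lora_into_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)

# Bind the patched function as a classmethod so `cls` resolves to the pipeline class.
# Note this rebinds the method on the class itself, not just on this instance.
pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)

# LoRA loading now routes through the patched path, which passes
# low_cpu_mem_usage=True to inject_adapter_in_model / set_peft_model_state_dict.
pipe.load_lora_weights("your-user/your-flux-lora", adapter_name="default")  # placeholder repo id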