Commit: ruff format

andreasjansson committed Sep 27, 2024
1 parent 34b4947, commit e636fb8

Showing 3 changed files with 46 additions and 13 deletions.
lora_loading_patch.py (36 additions, 9 deletions)
```diff
@@ -1,10 +1,19 @@
 # ruff: noqa
-from diffusers.utils import convert_unet_state_dict_to_peft, get_peft_kwargs, is_peft_version, get_adapter_name, logging
+from diffusers.utils import (
+    convert_unet_state_dict_to_peft,
+    get_peft_kwargs,
+    is_peft_version,
+    get_adapter_name,
+    logging,
+)

 logger = logging.get_logger(__name__)


 # patching inject_adapter_in_model and load_peft_state_dict with low_cpu_mem_usage=True until it's merged into diffusers
-def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+def load_lora_into_transformer(
+    cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None
+):
     """
     This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -29,7 +38,9 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):

     transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
     state_dict = {
-        k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        k.replace(f"{cls.transformer_name}.", ""): v
+        for k, v in state_dict.items()
+        if k in transformer_keys
     }

     if len(state_dict.keys()) > 0:
@@ -50,10 +61,20 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):

     if network_alphas is not None and len(network_alphas) >= 1:
         prefix = cls.transformer_name
-        alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-        network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+        alpha_keys = [
+            k
+            for k in network_alphas.keys()
+            if k.startswith(prefix) and k.split(".")[0] == prefix
+        ]
+        network_alphas = {
+            k.replace(f"{prefix}.", ""): v
+            for k, v in network_alphas.items()
+            if k in alpha_keys
+        }

-    lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+    lora_config_kwargs = get_peft_kwargs(
+        rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
+    )
     if "use_dora" in lora_config_kwargs:
         if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
             raise ValueError(
@@ -69,10 +90,16 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):

     # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
     # otherwise loading LoRA weights will lead to an error
-    is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+    is_model_cpu_offload, is_sequential_cpu_offload = (
+        cls._optionally_disable_offloading(_pipeline)
+    )

-    inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True)
-    incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, low_cpu_mem_usage=True)
+    inject_adapter_in_model(
+        lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True
+    )
+    incompatible_keys = set_peft_model_state_dict(
+        transformer, state_dict, adapter_name, low_cpu_mem_usage=True
+    )

     if incompatible_keys is not None:
         # check only for unexpected keys
```
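For context, the comprehensions reformatted above filter a combined LoRA state dict down to the transformer's entries and strip the `transformer.` prefix from the keys. A minimal, self-contained sketch of that transformation, using a toy state dict and the literal string `"transformer"` standing in for `cls.transformer_name`:

```python
# Toy illustration of the key filtering/renaming in load_lora_into_transformer;
# the values are placeholders standing in for LoRA weight tensors.
prefix = "transformer"  # stands in for cls.transformer_name
state_dict = {
    "transformer.blocks.0.lora_A.weight": "A0",
    "transformer.blocks.0.lora_B.weight": "B0",
    "text_encoder.lora_A.weight": "ignored",
}

# Keep only keys under the transformer prefix, then drop the prefix so the
# remaining keys line up with the transformer module's own parameter names.
transformer_keys = [k for k in state_dict if k.startswith(prefix)]
state_dict = {
    k.replace(f"{prefix}.", ""): v
    for k, v in state_dict.items()
    if k in transformer_keys
}

print(state_dict)
# {'blocks.0.lora_A.weight': 'A0', 'blocks.0.lora_B.weight': 'B0'}
```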
predict.py (9 additions, 3 deletions)
```diff
@@ -102,7 +102,9 @@ def setup(self) -> None: # pyright: ignore
             "FLUX.1-dev",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
-        dev_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
+        dev_pipe.__class__.load_lora_into_transformer = classmethod(
+            load_lora_into_transformer
+        )

         print("Loading Flux schnell pipeline")
         if not FLUX_SCHNELL_PATH.exists():
@@ -133,7 +135,9 @@ def setup(self) -> None: # pyright: ignore
             tokenizer=dev_pipe.tokenizer,
             tokenizer_2=dev_pipe.tokenizer_2,
         ).to("cuda")
-        dev_img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
+        dev_img2img_pipe.__class__.load_lora_into_transformer = classmethod(
+            load_lora_into_transformer
+        )

         print("Loading Flux schnell img2img pipeline")
         schnell_img2img_pipe = FluxImg2ImgPipeline(
@@ -162,7 +166,9 @@ def setup(self) -> None: # pyright: ignore
             tokenizer=dev_pipe.tokenizer,
             tokenizer_2=dev_pipe.tokenizer_2,
         ).to("cuda")
-        dev_inpaint_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
+        dev_inpaint_pipe.__class__.load_lora_into_transformer = classmethod(
+            load_lora_into_transformer
+        )

         print("Loading Flux schnell inpaint pipeline")
         schnell_inpaint_pipe = FluxInpaintPipeline(
```
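The assignments reformatted above replace each pipeline's stock `load_lora_into_transformer` with the patched version from lora_loading_patch.py; wrapping the plain function in `classmethod()` preserves the original binding, so the pipeline class is still passed in as `cls`. A minimal sketch of the pattern, using a hypothetical stand-in class rather than the real FluxPipeline:

```python
# Hypothetical stand-in for a diffusers pipeline class; not the real API.
class Pipeline:
    @classmethod
    def load_lora_into_transformer(cls, state_dict):
        return f"original loader on {cls.__name__}"

def patched_load_lora_into_transformer(cls, state_dict):
    return f"patched loader on {cls.__name__}"

pipe = Pipeline()
# Assigning through __class__ patches the class itself, so every pipeline
# sharing that class picks up the patched loader, not just this instance.
pipe.__class__.load_lora_into_transformer = classmethod(
    patched_load_lora_into_transformer
)

print(Pipeline.load_lora_into_transformer({}))  # -> patched loader on Pipeline
```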
ruff.toml (1 addition, 1 deletion)
```diff
@@ -27,7 +27,7 @@ exclude = [
     "site-packages",
     "venv",
     "ai-toolkit",
-    "LLaVA"
+    "LLaVA",
 ]

 # Same as Black.
```
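The trailing comma added to the `exclude` list mirrors the trailing-comma style that `ruff format` (following Black) enforces in the Python files above: a collection that ends in a "magic" trailing comma is kept exploded, one element per line. A small Python sketch of that behavior:

```python
# With a trailing comma, ruff format keeps the list one element per line,
# even though it would fit within the line length:
excluded = [
    "site-packages",
    "venv",
    "ai-toolkit",
    "LLaVA",
]

# Without the trailing comma, the formatter is free to collapse it:
# excluded = ["site-packages", "venv", "ai-toolkit", "LLaVA"]
```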
