diff --git a/lora_loading_patch.py b/lora_loading_patch.py
index 033e1c3..5e10eee 100644
--- a/lora_loading_patch.py
+++ b/lora_loading_patch.py
@@ -1,10 +1,19 @@
 # ruff: noqa
-from diffusers.utils import convert_unet_state_dict_to_peft, get_peft_kwargs, is_peft_version, get_adapter_name, logging
+from diffusers.utils import (
+    convert_unet_state_dict_to_peft,
+    get_peft_kwargs,
+    is_peft_version,
+    get_adapter_name,
+    logging,
+)
 
 logger = logging.get_logger(__name__)
 
+
 # patching inject_adapter_in_model and load_peft_state_dict with low_cpu_mem_usage=True until it's merged into diffusers
-def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
+def load_lora_into_transformer(
+    cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None
+):
     """
     This will load the LoRA layers specified in `state_dict` into `transformer`.
 
@@ -29,7 +38,9 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
 
     transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
     state_dict = {
-        k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
+        k.replace(f"{cls.transformer_name}.", ""): v
+        for k, v in state_dict.items()
+        if k in transformer_keys
     }
 
     if len(state_dict.keys()) > 0:
@@ -50,10 +61,20 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
 
         if network_alphas is not None and len(network_alphas) >= 1:
             prefix = cls.transformer_name
-            alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
-            network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+            alpha_keys = [
+                k
+                for k in network_alphas.keys()
+                if k.startswith(prefix) and k.split(".")[0] == prefix
+            ]
+            network_alphas = {
+                k.replace(f"{prefix}.", ""): v
+                for k, v in network_alphas.items()
+                if k in alpha_keys
+            }
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+        lora_config_kwargs = get_peft_kwargs(
+            rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
+        )
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
                 raise ValueError(
@@ -69,10 +90,16 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload = (
+            cls._optionally_disable_offloading(_pipeline)
+        )
 
-        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, low_cpu_mem_usage=True)
+        inject_adapter_in_model(
+            lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=True
+        )
+        incompatible_keys = set_peft_model_state_dict(
+            transformer, state_dict, adapter_name, low_cpu_mem_usage=True
+        )
 
         if incompatible_keys is not None:
             # check only for unexpected keys
diff --git a/predict.py b/predict.py
index f906dfa..866cb99 100644
--- a/predict.py
+++ b/predict.py
@@ -102,7 +102,9 @@ def setup(self) -> None: # pyright: ignore
             "FLUX.1-dev",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
-        dev_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
+        dev_pipe.__class__.load_lora_into_transformer = classmethod(
+            load_lora_into_transformer
+        )
 
         print("Loading Flux schnell pipeline")
         if not FLUX_SCHNELL_PATH.exists():
@@ -133,7 +135,9 @@ def setup(self) -> None: # pyright: ignore
             tokenizer=dev_pipe.tokenizer,
             tokenizer_2=dev_pipe.tokenizer_2,
         ).to("cuda")
-        dev_img2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
+        dev_img2img_pipe.__class__.load_lora_into_transformer = classmethod(
+            load_lora_into_transformer
+        )
 
         print("Loading Flux schnell img2img pipeline")
         schnell_img2img_pipe = FluxImg2ImgPipeline(
@@ -162,7 +166,9 @@ def setup(self) -> None: # pyright: ignore
             tokenizer=dev_pipe.tokenizer,
             tokenizer_2=dev_pipe.tokenizer_2,
         ).to("cuda")
-        dev_inpaint_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
+        dev_inpaint_pipe.__class__.load_lora_into_transformer = classmethod(
+            load_lora_into_transformer
+        )
 
         print("Loading Flux schnell inpaint pipeline")
         schnell_inpaint_pipe = FluxInpaintPipeline(
diff --git a/ruff.toml b/ruff.toml
index 0f7b5c0..f7e5bbf 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -27,7 +27,7 @@ exclude = [
     "site-packages",
     "venv",
     "ai-toolkit",
-    "LLaVA"
+    "LLaVA",
 ]
 
 # Same as Black.
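
For context, a minimal sketch of how the patched loader gets attached to a pipeline, mirroring the predict.py hunks above. It assumes diffusers' FluxPipeline and a local "FLUX.1-dev" checkpoint directory as in predict.py; the path is illustrative, not a pinned location.

    import torch
    from diffusers import FluxPipeline

    from lora_loading_patch import load_lora_into_transformer

    # Load the base pipeline (checkpoint path is illustrative, as in predict.py).
    pipe = FluxPipeline.from_pretrained(
        "FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Swap the classmethod on the pipeline class so that subsequent
    # pipe.load_lora_weights(...) calls route through the patched loader,
    # which forwards low_cpu_mem_usage=True to PEFT's inject_adapter_in_model
    # and set_peft_model_state_dict.
    pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)

Assigning on __class__ and wrapping in classmethod matches how diffusers defines the original method, so load_lora_weights (which dispatches to cls.load_lora_into_transformer) picks up the patched version.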