diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 338175b75163..df1d351ca1f7 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ network_alphas = {} # Check for DoRA-enabled LoRAs. - if any( - "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k) - for k in state_dict - ): + dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) + dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) + dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) + if dora_present_in_unet or dora_present_in_te or dora_present_in_te2: if is_peft_version("<", "0.9.0"): raise ValueError( "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." @@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Store DoRA scale if present. - if "dora_scale" in state_dict: + if dora_present_in_unet: dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down." unet_state_dict[ diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.") @@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) # Store DoRA scale if present. - if "dora_scale" in state_dict: + if dora_present_in_te or dora_present_in_te2: dora_scale_key_to_replace_te = ( "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer." ) @@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_ if len(state_dict) > 0: raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}") - logger.info("Kohya-style checkpoint detected.") + logger.info("Non-diffusers checkpoint detected.") # Construct final state dict. unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}