Update handle single blocks on _convert_xlabs_flux_lora_to_diffusers #9915

Open · wants to merge 3 commits into main

Conversation

raulmosa

What does this PR do?

Proposal to update the following script for Xlabs Flux LoRA conversion, due to a mismatch between keys in the state dictionary:
src/diffusers/loaders/lora_conversion_utils.py
When mapping single_blocks layers, if a LoRA trained on Flux contains single_blocks, those keys are never converted and never removed from old_state_dict (see lines 635-655), so the following ValueError is reached:

if len(old_state_dict) > 0:
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

For example, the keys of a working Flux LoRA model (XLabs-AI/flux-RealismLora); it doesn't contain single_blocks:
['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight', 'double_blocks.1.processor.proj_lora1.down.weight', 'double_blocks.1.processor.proj_lora1.up.weight', 'double_blocks.1.processor.proj_lora2.down.weight', 'double_blocks.1.processor.proj_lora2.up.weight', 'double_blocks.1.processor.qkv_lora1.down.weight', 'double_blocks.1.processor.qkv_lora1.up.weight', 'double_blocks.1.processor.qkv_lora2.down.weight', 'double_blocks.1.processor.qkv_lora2.up.weight', 'double_blocks.10.processor.proj_lora1.down.weight', 'double_blocks.10.processor.proj_lora1.up.weight', 'double_blocks.10.processor.proj_lora2.down.weight', 'double_blocks.10.processor.proj_lora2.up.weight', 'double_blocks.10.processor.qkv_lora1.down.weight', 'double_blocks.10.processor.qkv_lora1.up.weight', 'double_blocks.10.processor.qkv_lora2.down.weight', 'double_blocks.10.processor.qkv_lora2.up.weight', 'double_blocks.11.processor.proj_lora1.down.weight', 'double_blocks.11.processor.proj_lora1.up.weight', 'double_blocks.11.processor.proj_lora2.down.weight', 'double_blocks.11.processor.proj_lora2.up.weight', 'double_blocks.11.processor.qkv_lora1.down.weight', 'double_blocks.11.processor.qkv_lora1.up.weight', 'double_blocks.11.processor.qkv_lora2.down.weight', 'double_blocks.11.processor.qkv_lora2.up.weight', 'double_blocks.12.processor.proj_lora1.down.weight', 'double_blocks.12.processor.proj_lora1.up.weight', 'double_blocks.12.processor.proj_lora2.down.weight', 'double_blocks.12.processor.proj_lora2.up.weight', 'double_blocks.12.processor.qkv_lora1.down.weight', 'double_blocks.12.processor.qkv_lora1.up.weight', 'double_blocks.12.processor.qkv_lora2.down.weight', 'double_blocks.12.processor.qkv_lora2.up.weight', 'double_blocks.13.processor.proj_lora1.down.weight', 'double_blocks.13.processor.proj_lora1.up.weight', 'double_blocks.13.processor.proj_lora2.down.weight', 'double_blocks.13.processor.proj_lora2.up.weight', 'double_blocks.13.processor.qkv_lora1.down.weight', 'double_blocks.13.processor.qkv_lora1.up.weight', 'double_blocks.13.processor.qkv_lora2.down.weight', 'double_blocks.13.processor.qkv_lora2.up.weight', 'double_blocks.14.processor.proj_lora1.down.weight', 'double_blocks.14.processor.proj_lora1.up.weight', 'double_blocks.14.processor.proj_lora2.down.weight', 'double_blocks.14.processor.proj_lora2.up.weight', 'double_blocks.14.processor.qkv_lora1.down.weight', 'double_blocks.14.processor.qkv_lora1.up.weight', 'double_blocks.14.processor.qkv_lora2.down.weight', 'double_blocks.14.processor.qkv_lora2.up.weight', 'double_blocks.15.processor.proj_lora1.down.weight', 'double_blocks.15.processor.proj_lora1.up.weight', 'double_blocks.15.processor.proj_lora2.down.weight', 'double_blocks.15.processor.proj_lora2.up.weight', 'double_blocks.15.processor.qkv_lora1.down.weight', 'double_blocks.15.processor.qkv_lora1.up.weight', 'double_blocks.15.processor.qkv_lora2.down.weight', 'double_blocks.15.processor.qkv_lora2.up.weight', 'double_blocks.16.processor.proj_lora1.down.weight', 'double_blocks.16.processor.proj_lora1.up.weight', 'double_blocks.16.processor.proj_lora2.down.weight', 'double_blocks.16.processor.proj_lora2.up.weight', 'double_blocks.16.processor.qkv_lora1.down.weight', 
'double_blocks.16.processor.qkv_lora1.up.weight', 'double_blocks.16.processor.qkv_lora2.down.weight', 'double_blocks.16.processor.qkv_lora2.up.weight', 'double_blocks.17.processor.proj_lora1.down.weight', 'double_blocks.17.processor.proj_lora1.up.weight', 'double_blocks.17.processor.proj_lora2.down.weight', 'double_blocks.17.processor.proj_lora2.up.weight', 'double_blocks.17.processor.qkv_lora1.down.weight', 'double_blocks.17.processor.qkv_lora1.up.weight', 'double_blocks.17.processor.qkv_lora2.down.weight', 'double_blocks.17.processor.qkv_lora2.up.weight', 'double_blocks.18.processor.proj_lora1.down.weight', 'double_blocks.18.processor.proj_lora1.up.weight', 'double_blocks.18.processor.proj_lora2.down.weight', 'double_blocks.18.processor.proj_lora2.up.weight', 'double_blocks.18.processor.qkv_lora1.down.weight', 'double_blocks.18.processor.qkv_lora1.up.weight', 'double_blocks.18.processor.qkv_lora2.down.weight', 'double_blocks.18.processor.qkv_lora2.up.weight', 'double_blocks.2.processor.proj_lora1.down.weight', 'double_blocks.2.processor.proj_lora1.up.weight', 'double_blocks.2.processor.proj_lora2.down.weight', 'double_blocks.2.processor.proj_lora2.up.weight', 'double_blocks.2.processor.qkv_lora1.down.weight', 'double_blocks.2.processor.qkv_lora1.up.weight', 'double_blocks.2.processor.qkv_lora2.down.weight', 'double_blocks.2.processor.qkv_lora2.up.weight', 'double_blocks.3.processor.proj_lora1.down.weight', 'double_blocks.3.processor.proj_lora1.up.weight', 'double_blocks.3.processor.proj_lora2.down.weight', 'double_blocks.3.processor.proj_lora2.up.weight', 'double_blocks.3.processor.qkv_lora1.down.weight', 'double_blocks.3.processor.qkv_lora1.up.weight', 'double_blocks.3.processor.qkv_lora2.down.weight', 'double_blocks.3.processor.qkv_lora2.up.weight', 'double_blocks.4.processor.proj_lora1.down.weight', 'double_blocks.4.processor.proj_lora1.up.weight', 'double_blocks.4.processor.proj_lora2.down.weight', 'double_blocks.4.processor.proj_lora2.up.weight', 'double_blocks.4.processor.qkv_lora1.down.weight', 'double_blocks.4.processor.qkv_lora1.up.weight', 'double_blocks.4.processor.qkv_lora2.down.weight', 'double_blocks.4.processor.qkv_lora2.up.weight', 'double_blocks.5.processor.proj_lora1.down.weight', 'double_blocks.5.processor.proj_lora1.up.weight', 'double_blocks.5.processor.proj_lora2.down.weight', 'double_blocks.5.processor.proj_lora2.up.weight', 'double_blocks.5.processor.qkv_lora1.down.weight', 'double_blocks.5.processor.qkv_lora1.up.weight', 'double_blocks.5.processor.qkv_lora2.down.weight', 'double_blocks.5.processor.qkv_lora2.up.weight', 'double_blocks.6.processor.proj_lora1.down.weight', 'double_blocks.6.processor.proj_lora1.up.weight', 'double_blocks.6.processor.proj_lora2.down.weight', 'double_blocks.6.processor.proj_lora2.up.weight', 'double_blocks.6.processor.qkv_lora1.down.weight', 'double_blocks.6.processor.qkv_lora1.up.weight', 'double_blocks.6.processor.qkv_lora2.down.weight', 'double_blocks.6.processor.qkv_lora2.up.weight', 'double_blocks.7.processor.proj_lora1.down.weight', 'double_blocks.7.processor.proj_lora1.up.weight', 'double_blocks.7.processor.proj_lora2.down.weight', 'double_blocks.7.processor.proj_lora2.up.weight', 'double_blocks.7.processor.qkv_lora1.down.weight', 'double_blocks.7.processor.qkv_lora1.up.weight', 'double_blocks.7.processor.qkv_lora2.down.weight', 'double_blocks.7.processor.qkv_lora2.up.weight', 'double_blocks.8.processor.proj_lora1.down.weight', 'double_blocks.8.processor.proj_lora1.up.weight', 'double_blocks.8.processor.proj_lora2.down.weight', 
'double_blocks.8.processor.proj_lora2.up.weight', 'double_blocks.8.processor.qkv_lora1.down.weight', 'double_blocks.8.processor.qkv_lora1.up.weight', 'double_blocks.8.processor.qkv_lora2.down.weight', 'double_blocks.8.processor.qkv_lora2.up.weight', 'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight']

And below, an example of a LoRA trained with the current Xlabs code, which does contain single_blocks:
['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight', 'double_blocks.1.processor.proj_lora1.down.weight', 'double_blocks.1.processor.proj_lora1.up.weight', 'double_blocks.1.processor.proj_lora2.down.weight', 'double_blocks.1.processor.proj_lora2.up.weight', 'double_blocks.1.processor.qkv_lora1.down.weight', 'double_blocks.1.processor.qkv_lora1.up.weight', 'double_blocks.1.processor.qkv_lora2.down.weight', 'double_blocks.1.processor.qkv_lora2.up.weight', 'double_blocks.10.processor.proj_lora1.down.weight', 'double_blocks.10.processor.proj_lora1.up.weight', 'double_blocks.10.processor.proj_lora2.down.weight', 'double_blocks.10.processor.proj_lora2.up.weight', 'double_blocks.10.processor.qkv_lora1.down.weight', 'double_blocks.10.processor.qkv_lora1.up.weight', 'double_blocks.10.processor.qkv_lora2.down.weight', 'double_blocks.10.processor.qkv_lora2.up.weight', 'double_blocks.11.processor.proj_lora1.down.weight', 'double_blocks.11.processor.proj_lora1.up.weight', 'double_blocks.11.processor.proj_lora2.down.weight', 'double_blocks.11.processor.proj_lora2.up.weight', 'double_blocks.11.processor.qkv_lora1.down.weight', 'double_blocks.11.processor.qkv_lora1.up.weight', 'double_blocks.11.processor.qkv_lora2.down.weight', 'double_blocks.11.processor.qkv_lora2.up.weight', 'double_blocks.12.processor.proj_lora1.down.weight', 'double_blocks.12.processor.proj_lora1.up.weight', 'double_blocks.12.processor.proj_lora2.down.weight', 'double_blocks.12.processor.proj_lora2.up.weight', 'double_blocks.12.processor.qkv_lora1.down.weight', 'double_blocks.12.processor.qkv_lora1.up.weight', 'double_blocks.12.processor.qkv_lora2.down.weight', 'double_blocks.12.processor.qkv_lora2.up.weight', 'double_blocks.13.processor.proj_lora1.down.weight', 'double_blocks.13.processor.proj_lora1.up.weight', 'double_blocks.13.processor.proj_lora2.down.weight', 'double_blocks.13.processor.proj_lora2.up.weight', 'double_blocks.13.processor.qkv_lora1.down.weight', 'double_blocks.13.processor.qkv_lora1.up.weight', 'double_blocks.13.processor.qkv_lora2.down.weight', 'double_blocks.13.processor.qkv_lora2.up.weight', 'double_blocks.14.processor.proj_lora1.down.weight', 'double_blocks.14.processor.proj_lora1.up.weight', 'double_blocks.14.processor.proj_lora2.down.weight', 'double_blocks.14.processor.proj_lora2.up.weight', 'double_blocks.14.processor.qkv_lora1.down.weight', 'double_blocks.14.processor.qkv_lora1.up.weight', 'double_blocks.14.processor.qkv_lora2.down.weight', 'double_blocks.14.processor.qkv_lora2.up.weight', 'double_blocks.15.processor.proj_lora1.down.weight', 'double_blocks.15.processor.proj_lora1.up.weight', 'double_blocks.15.processor.proj_lora2.down.weight', 'double_blocks.15.processor.proj_lora2.up.weight', 'double_blocks.15.processor.qkv_lora1.down.weight', 'double_blocks.15.processor.qkv_lora1.up.weight', 'double_blocks.15.processor.qkv_lora2.down.weight', 'double_blocks.15.processor.qkv_lora2.up.weight', 'double_blocks.16.processor.proj_lora1.down.weight', 'double_blocks.16.processor.proj_lora1.up.weight', 'double_blocks.16.processor.proj_lora2.down.weight', 'double_blocks.16.processor.proj_lora2.up.weight', 'double_blocks.16.processor.qkv_lora1.down.weight', 
'double_blocks.16.processor.qkv_lora1.up.weight', 'double_blocks.16.processor.qkv_lora2.down.weight', 'double_blocks.16.processor.qkv_lora2.up.weight', 'double_blocks.17.processor.proj_lora1.down.weight', 'double_blocks.17.processor.proj_lora1.up.weight', 'double_blocks.17.processor.proj_lora2.down.weight', 'double_blocks.17.processor.proj_lora2.up.weight', 'double_blocks.17.processor.qkv_lora1.down.weight', 'double_blocks.17.processor.qkv_lora1.up.weight', 'double_blocks.17.processor.qkv_lora2.down.weight', 'double_blocks.17.processor.qkv_lora2.up.weight', 'double_blocks.18.processor.proj_lora1.down.weight', 'double_blocks.18.processor.proj_lora1.up.weight', 'double_blocks.18.processor.proj_lora2.down.weight', 'double_blocks.18.processor.proj_lora2.up.weight', 'double_blocks.18.processor.qkv_lora1.down.weight', 'double_blocks.18.processor.qkv_lora1.up.weight', 'double_blocks.18.processor.qkv_lora2.down.weight', 'double_blocks.18.processor.qkv_lora2.up.weight', 'double_blocks.2.processor.proj_lora1.down.weight', 'double_blocks.2.processor.proj_lora1.up.weight', 'double_blocks.2.processor.proj_lora2.down.weight', 'double_blocks.2.processor.proj_lora2.up.weight', 'double_blocks.2.processor.qkv_lora1.down.weight', 'double_blocks.2.processor.qkv_lora1.up.weight', 'double_blocks.2.processor.qkv_lora2.down.weight', 'double_blocks.2.processor.qkv_lora2.up.weight', 'double_blocks.3.processor.proj_lora1.down.weight', 'double_blocks.3.processor.proj_lora1.up.weight', 'double_blocks.3.processor.proj_lora2.down.weight', 'double_blocks.3.processor.proj_lora2.up.weight', 'double_blocks.3.processor.qkv_lora1.down.weight', 'double_blocks.3.processor.qkv_lora1.up.weight', 'double_blocks.3.processor.qkv_lora2.down.weight', 'double_blocks.3.processor.qkv_lora2.up.weight', 'double_blocks.4.processor.proj_lora1.down.weight', 'double_blocks.4.processor.proj_lora1.up.weight', 'double_blocks.4.processor.proj_lora2.down.weight', 'double_blocks.4.processor.proj_lora2.up.weight', 'double_blocks.4.processor.qkv_lora1.down.weight', 'double_blocks.4.processor.qkv_lora1.up.weight', 'double_blocks.4.processor.qkv_lora2.down.weight', 'double_blocks.4.processor.qkv_lora2.up.weight', 'double_blocks.5.processor.proj_lora1.down.weight', 'double_blocks.5.processor.proj_lora1.up.weight', 'double_blocks.5.processor.proj_lora2.down.weight', 'double_blocks.5.processor.proj_lora2.up.weight', 'double_blocks.5.processor.qkv_lora1.down.weight', 'double_blocks.5.processor.qkv_lora1.up.weight', 'double_blocks.5.processor.qkv_lora2.down.weight', 'double_blocks.5.processor.qkv_lora2.up.weight', 'double_blocks.6.processor.proj_lora1.down.weight', 'double_blocks.6.processor.proj_lora1.up.weight', 'double_blocks.6.processor.proj_lora2.down.weight', 'double_blocks.6.processor.proj_lora2.up.weight', 'double_blocks.6.processor.qkv_lora1.down.weight', 'double_blocks.6.processor.qkv_lora1.up.weight', 'double_blocks.6.processor.qkv_lora2.down.weight', 'double_blocks.6.processor.qkv_lora2.up.weight', 'double_blocks.7.processor.proj_lora1.down.weight', 'double_blocks.7.processor.proj_lora1.up.weight', 'double_blocks.7.processor.proj_lora2.down.weight', 'double_blocks.7.processor.proj_lora2.up.weight', 'double_blocks.7.processor.qkv_lora1.down.weight', 'double_blocks.7.processor.qkv_lora1.up.weight', 'double_blocks.7.processor.qkv_lora2.down.weight', 'double_blocks.7.processor.qkv_lora2.up.weight', 'double_blocks.8.processor.proj_lora1.down.weight', 'double_blocks.8.processor.proj_lora1.up.weight', 'double_blocks.8.processor.proj_lora2.down.weight', 
'double_blocks.8.processor.proj_lora2.up.weight', 'double_blocks.8.processor.qkv_lora1.down.weight', 'double_blocks.8.processor.qkv_lora1.up.weight', 'double_blocks.8.processor.qkv_lora2.down.weight', 'double_blocks.8.processor.qkv_lora2.up.weight', 'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight', 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight']

The script works after changing lines 639-642 to:

if "proj_lora" in old_key:
  new_key += ".proj_out"
elif "qkv_lora" in old_key and "up" not in old_key:
  handle_qkv(old_state_dict, new_state_dict, old_key, [
    f"transformer.single_transformer_blocks.{block_num}.norm.linear"
  ])
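
For reference, here is a minimal sketch of how one single_blocks key is then mapped end to end (illustrative only; the input key and resulting name are taken from the examples in this PR, and the down/up -> lora_A/lora_B renaming follows the PEFT convention used by diffusers):

old_key = "single_blocks.1.processor.proj_lora.down.weight"
block_num = old_key.split(".")[1]  # "1"
new_key = f"transformer.single_transformer_blocks.{block_num}"
if "proj_lora" in old_key:
    new_key += ".proj_out"
# down -> lora_A, up -> lora_B
new_key += ".lora_A.weight" if ".down." in old_key else ".lora_B.weight"
print(new_key)  # transformer.single_transformer_blocks.1.proj_out.lora_A.weight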

Related PR: #9295 (cc @sayakpaul)

Reproduction

import torch
from diffusers import DiffusionPipeline


model_path = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
lora_model_path = "XLabs-AI/flslux-RealismLora"
# lora_model_path = "<PATH-LoRA-trained-Xlabs.safetensors>"
pipe.load_lora_weights(lora_model_path, adapter_name="lora_A")
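
To check up front whether a given LoRA file contains single_blocks keys, the keys can be listed with safetensors (a sketch; the local file path is a placeholder):

from safetensors import safe_open

with safe_open("lora.safetensors", framework="pt", device="cpu") as f:  # placeholder path
    keys = list(f.keys())
print([k for k in keys if k.startswith("single_blocks")])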

Logs

When a custom LoRA trained with the Xlabs code and containing single_blocks is loaded:

File "/home/.pyenv/versions/xflux/lib/python3.10/site-packages/diffusers/loaders/lora_conversion_utils.py", line 658, in _convert_xlabs_flux_lora_to_diffusers
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
ValueError: `old_state_dict` should be at this point but has: ['single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'].

System Info

  • 🤗 Diffusers version: 0.31.0
  • Platform: Linux-5.10.0-33-cloud-amd64-x86_64-with-glibc2.31
  • Running on Google Colab?: No
  • Python version: 3.10.14
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.24.5
  • Transformers version: 4.43.3
  • Accelerate version: 0.30.1
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can review?

@sayakpaul @yiyixuxu

…to fix bug on updating keys and old_state_dict
@sayakpaul (Member) left a comment:

Thanks for your work!

Can you run

def test_flux_xlabs(self):

to see if we're not introducing any breaking changes? You will have to comment out

@unittest.skip("We cannot run inference on this model with the current CI hardware")

Additionally, for the record, lora_model_path = "XLabs-AI/flux-RealismLora" is not a reproduction of the issue this PR tries to solve, as the LoRA in that repo doesn't have any single transformer blocks.

if "proj_lora1" in old_key or "proj_lora2" in old_key:
# if "proj_lora1" in old_key or "proj_lora2" in old_key:
@sayakpaul (Member):

What's with the comments?

@raulmosa (Author), Nov 12, 2024:

This comment is the previous code and should be removed.
The reason for the change in this part of the code is that single blocks in Xlabs Flux LoRA do not contain "proj_lora1" or "proj_lora2"; the string is "proj_lora".
See the example below of the old_state_dict of a LoRA model where single blocks 1 to 4 are trained (only keys for double block 9 and the single blocks are shown):

  • Double blocks keys example:
'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight',
  • Single blocks keys example:
 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'

Then, with the previous line of code, single_blocks keys are never added to new_state_dict and never removed from old_state_dict, as the quick check below shows.
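
Quick check (the key is taken from the single-blocks example above):

old_key = "single_blocks.1.processor.proj_lora.down.weight"
print("proj_lora1" in old_key or "proj_lora2" in old_key)  # False -> previous branch never matches
print("proj_lora" in old_key)                              # True  -> updated branch matches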

Comment on lines 641 to 646
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
new_key += ".norm.linear"
# elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
elif "qkv_lora" in old_key and "up" not in old_key:
handle_qkv(old_state_dict, new_state_dict, old_key, [
f"transformer.single_transformer_blocks.{block_num}.norm.linear"
])
@sayakpaul (Member):

Can you explain this change to me?

@raulmosa (Author), Nov 12, 2024:

Sure. Here again, I forgot to remove the commented line; it is from the previous code snippet.
Using the same example as in the previous comment, here is the original old_state_dict of a LoRA model where single blocks 1 to 4 are trained (only keys for double block 9 and the single blocks are shown):

  • Double block keys example:
'double_blocks.9.processor.proj_lora1.down.weight', 'double_blocks.9.processor.proj_lora1.up.weight', 'double_blocks.9.processor.proj_lora2.down.weight', 'double_blocks.9.processor.proj_lora2.up.weight', 'double_blocks.9.processor.qkv_lora1.down.weight', 'double_blocks.9.processor.qkv_lora1.up.weight', 'double_blocks.9.processor.qkv_lora2.down.weight', 'double_blocks.9.processor.qkv_lora2.up.weight',
  • Single blocks keys example:
 'single_blocks.1.processor.proj_lora.down.weight', 'single_blocks.1.processor.proj_lora.up.weight', 'single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.proj_lora.down.weight', 'single_blocks.2.processor.proj_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.proj_lora.down.weight', 'single_blocks.3.processor.proj_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'

qkv_lora1 and qkv_lora2 are not present in single blocks; the key is qkv_lora. I've therefore used the same logic and function used to handle double blocks, i.e., handle_qkv, which updates new_state_dict and removes the keys from old_state_dict. Then, in the last part of the code:

    # Since we already handle qkv above.
    if "qkv" not in old_key:
        new_state_dict[new_key] = old_state_dict.pop(old_key)

if len(old_state_dict) > 0:
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

All "qkv" for double and single blocks are handled and ValueError is not raised.

@raulmosa (Author) commented on Nov 12, 2024:

Logs - note the single_blocks keys in both cases:

State dicts before adapting to diffusers, i.e., at the beginning of the function:

def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    new_state_dict = {}
    orig_keys = list(old_state_dict.keys())
    print(old_state_dict.keys())
    ...
dict_keys(['double_blocks.0.processor.proj_lora1.down.weight', 'double_blocks.0.processor.proj_lora1.up.weight', 'double_blocks.0.processor.proj_lora2.down.weight', 'double_blocks.0.processor.proj_lora2.up.weight', 'double_blocks.0.processor.qkv_lora1.down.weight', 'double_blocks.0.processor.qkv_lora1.up.weight', 'double_blocks.0.processor.qkv_lora2.down.weight', 'double_blocks.0.processor.qkv_lora2.up.weight',...  'single_blocks.4.processor.proj_lora.down.weight', 'single_blocks.4.processor.proj_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight'])

Now if we print the old_state_dict and new_state_dict at the end of the function _convert_xlabs_flux_lora_to_diffusers:

    ...
    # Since we already handle qkv above.
    if "qkv" not in old_key:
        new_state_dict[new_key] = old_state_dict.pop(old_key)

print(old_state_dict.keys())
print(new_state_dict.keys())
if len(old_state_dict) > 0:
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

  • Output of the original code, before the changes:

    • old_state_dict is not emptied and still keeps these keys:
dict_keys(['single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight']
    • new_state_dict single-block keys (note the missing submodule between the block index and lora_A/lora_B):
'transformer.single_transformer_blocks.1.lora_A.weight', 'transformer.single_transformer_blocks.1.lora_B.weight', 'transformer.single_transformer_blocks.2.lora_A.weight', 'transformer.single_transformer_blocks.2.lora_B.weight', 'transformer.single_transformer_blocks.3.lora_A.weight', 'transformer.single_transformer_blocks.3.lora_B.weight', 'transformer.single_transformer_blocks.4.lora_A.weight', 'transformer.single_transformer_blocks.4.lora_B.weight'
    • Error raised:
  File "/home/raul/.pyenv/versions/xflux/lib/python3.10/site-packages/diffusers/loaders/lora_conversion_utils.py", line 659, in _convert_xlabs_flux_lora_to_diffusers
    raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
ValueError: `old_state_dict` should be at this point but has: ['single_blocks.1.processor.qkv_lora.down.weight', 'single_blocks.1.processor.qkv_lora.up.weight', 'single_blocks.2.processor.qkv_lora.down.weight', 'single_blocks.2.processor.qkv_lora.up.weight', 'single_blocks.3.processor.qkv_lora.down.weight', 'single_blocks.3.processor.qkv_lora.up.weight', 'single_blocks.4.processor.qkv_lora.down.weight', 'single_blocks.4.processor.qkv_lora.up.weight']
  • Output after code update:
    • old_state_dict is now empty and the error is not raised:
dict_keys([])
    • new_state_dict single-block keys now look like:
'transformer.single_transformer_blocks.1.proj_out.lora_A.weight', 'transformer.single_transformer_blocks.1.proj_out.lora_B.weight', 'transformer.single_transformer_blocks.1.norm.linear.lora_A.weight', 'transformer.single_transformer_blocks.1.norm.linear.lora_B.weight', 'transformer.single_transformer_blocks.2.proj_out.lora_A.weight', 'transformer.single_transformer_blocks.2.proj_out.lora_B.weight', 'transformer.single_transformer_blocks.2.norm.linear.lora_A.weight', 'transformer.single_transformer_blocks.2.norm.linear.lora_B.weight', 'transformer.single_transformer_blocks.3.proj_out.lora_A.weight', 'transformer.single_transformer_blocks.3.proj_out.lora_B.weight', 'transformer.single_transformer_blocks.3.norm.linear.lora_A.weight', 'transformer.single_transformer_blocks.3.norm.linear.lora_B.weight', 'transformer.single_transformer_blocks.4.proj_out.lora_A.weight', 'transformer.single_transformer_blocks.4.proj_out.lora_B.weight', 'transformer.single_transformer_blocks.4.norm.linear.lora_A.weight', 'transformer.single_transformer_blocks.4.norm.linear.lora_B.weight'

@raulmosa (Author):

Test output

pytest tests/lora/test_lora_layers_flux.py
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/raul/workspace/diffusers
configfile: pyproject.toml
plugins: requests-mock-1.10.0, timeout-2.3.1, xdist-3.6.1
collected 36 items

tests/lora/test_lora_layers_flux.py ......s.....s...s...............ssss [100%]

=============================== warnings summary ===============================
tests/lora/test_lora_layers_flux.py::FluxLoRATests::test_simple_inference_save_pretrained
  /home/raul/.pyenv/versions/3.10.14/envs/dev_diff/lib/python3.10/site-packages/transformers/integrations/peft.py:418: FutureWarning: The `active_adapter` method is deprecated and will be removed in a future version.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 29 passed, 7 skipped, 1 warning in 47.96s ===================

@sayakpaul (Member) commented on Nov 12, 2024:

Sorry, I am quite confused now.

  1. Can you please edit your comment to show only parts of the state dict? It's difficult to follow long logs of state-dict keys, and it will be easier to compare what's happening that way.
  2. The test you ran is skipping things, it seems. To run this particular test, you will need to do the following:
export RUN_SLOW=1
export RUN_NIGHTLY=1

Then comment out:

@unittest.skip("We cannot run inference on this model with the current CI hardware")

(this will not be required once #9845 is in)

And then run:

pytest tests/lora/test_lora_layers_flux.py::FluxLoRAIntegrationTests::test_flux_xlabs

Apologies for not making it clearer in my first comment.

@raulmosa (Author):

Hi @sayakpaul,
Sorry about the long logs; I've updated my comments, please let me know if anything is still unclear.
Regarding the test, I've run it again following your comment:

pytest tests/lora/test_lora_layers_flux.py::FluxLoRAIntegrationTests::test_flux_xlabs

============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/raul/workspace/diffusers
configfile: pyproject.toml
plugins: requests-mock-1.10.0, timeout-2.3.1, xdist-3.6.1
collected 1 item

tests/lora/test_lora_layers_flux.py .                                           [100%]

============================ 1 passed in 133.97s (0:02:13) ============================

Regarding the code snippet to reproduce the issue:

  • lora_model_path = "XLabs-AI/flux-RealismLora" -> a model on the HF Hub without single blocks; it works and does not raise any error.
  • Using instead an example like lora_model_path = "<PATH-LoRA-trained-Xlabs.safetensors>" -> if the model has been trained with single blocks, it will raise the error. If necessary, I can push a safetensors file example to reproduce it.

@sayakpaul (Member) commented on Nov 13, 2024:

Thanks so much, I understand the issue now. Can you also add a test for this?

@raulmosa (Author):

I've uploaded a Flux LoRA model trained with Xlabs that contains single blocks: salinasr/test_xlabs_flux_lora_with_singleblocks
Testing that the LoRA model loads correctly would check these changes:

    def test_flux_xlabs_load_lora_with_single_blocks(self):
        try:
            self.pipeline.load_lora_weights(
                "salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
            )
        except Exception as e:
            self.fail(f"load_lora_weights raised an exception: {e}")

I've added this function right after test_flux_xlabs in test_lora_layers_flux.py.
Let me know if this would be OK; if so, should I update the PR directly?

@sayakpaul (Member):

Thanks, I think you could add a test case similar to:

def test_flux_xlabs(self):

WDYT?

@raulmosa (Author):

I've removed the unused comments (discussed above) and added the test as you suggested:

    def test_flux_xlabs_load_lora_with_single_blocks(self):
        self.pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks",
                                        weight_name="lora.safetensors")
        self.pipeline.fuse_lora()
        self.pipeline.unload_lora_weights()
        self.pipeline.enable_model_cpu_offload()

        prompt = "a wizard mouse playing chess"

        out = self.pipeline(
            prompt,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=3.5,
            output_type="np",
            generator=torch.manual_seed(self.seed),
        ).images
        out_slice = out[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625])
        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)

        assert max_diff < 1e-3

@sayakpaul (Member) left a comment:

Thanks! This works for me.

I just tested on our CI and the slices fail because of the hardware changes. This is fine IMO as I can always update it here: #9845

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
