
Implements Blockwise lora #7352

Merged: 40 commits from 7231-blockwise-lora into huggingface:main on Mar 29, 2024
Conversation

@UmerHA (Contributor) commented Mar 16, 2024

What does this PR do?

Allows setting LoRA weights more granularly, up to per-transformer block.

Fixes #7231

Example usage:

pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter") 
scales = {
    "text_encoder": 0.5,
    "unet": {
        "down": 0.9,  # all transformers in the down-part will use scale 0.9
        # "mid"  # because "mid" is not given, all transformers in the mid part will use the default scale 1.0
        "up": {
            "block_0": 0.6,  # all 3 transformers in the 0th block in the up-part will use scale 0.6
            "block_1": [0.4, 0.8, 1.0],  # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
        }
    }
}
pipe.set_adapters("my_adapter", scales)


Who can review?

Core library:

@dain5832: As you're the author of the solved issue, I invite you to also test this.

@UmerHA changed the title from "7231 blockwise lora" to "Blockwise lora" Mar 16, 2024
@UmerHA changed the title from "Blockwise lora" to "Implements Blockwise lora" Mar 16, 2024
@asomoza (Member) commented Mar 16, 2024

This PR is really nice; I like being able to have this kind of control over the LoRAs. Just by playing with it, I found a way to break it though. If you do this:

scales = {
    "text_encoder": 0.5,
}
pipeline.set_adapters(["test"], adapter_weights=[scales])

it throws:

unet.py", line 611, in _expand_lora_scales_dict
    if "mid" not in scales:
TypeError: argument of type 'NoneType' is not iterable

Also, I'd like to request that you add control over the second text encoder in SDXL, but if it's a rare edge case it doesn't matter; I can do it on my side afterwards.

@UmerHA (Contributor, Author) commented Mar 16, 2024

> This PR is really nice

❤️

> Just by playing with it, I found a way to break it though.

Fixed + expanded tests.

> Also, I'd like to request that you add control over the second text encoder in SDXL

Added.


Let me know if there are any other issues!
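
For readers following along, here is a minimal sketch of the kind of None-guard such a fix implies: missing or None entries in the scales dict fall back to the default scale instead of being iterated over. The helper name normalize_unet_scales is hypothetical; the actual change lives in the PR's loader code (src/diffusers/loaders/unet.py):

# Minimal sketch, not the PR's actual code.
def normalize_unet_scales(scales, default_scale=1.0):
    unet_scales = scales.get("unet") if isinstance(scales, dict) else None
    if unet_scales is None:
        # e.g. scales = {"text_encoder": 0.5} -> the unet keeps the default scale everywhere
        return {"down": default_scale, "mid": default_scale, "up": default_scale}
    if isinstance(unet_scales, (int, float)):
        # a single float applies to every part of the unet
        return {"down": unet_scales, "mid": unet_scales, "up": unet_scales}
    return {part: unet_scales.get(part, default_scale) for part in ("down", "mid", "up")}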

@sayakpaul (Member) left a comment

I think the PR has a very nice start. I left some initial comments. We'll need to make sure we add rigorous checks on the inputs to prevent anything unexpected that we can foresee from experience.

Review threads (resolved): docs/source/en/using-diffusers/loading_adapters.md (2), src/diffusers/loaders/lora.py (2), src/diffusers/loaders/unet.py (4)
@UmerHA (Contributor, Author) commented Mar 18, 2024

Re mid block: We can pass a "mid" key in the scales dict. It's commented out in the example to show that not all keys need to be provided. But the mid block only needs one scale value, because it is a single block with a single layer.

In principle, mid blocks could have more layers, but UNet2DConditionModel doesn't pass num_layers to the mid block, so it always defaults to 1.

Re convs: The original implementation doesn't implement scaling of the conv layers.

  • Would you like me to also implement scaling of the conv layers?
  • Would you like me to also implement scaling of the resnets? Currently, only the transformers are scaled.

Let me know, then I'll adapt the code.

@sayakpaul

@sayakpaul (Member) commented

> Re mid block: We can pass a "mid" key in the scales dict. It's commented out in the example to show that not all keys need to be provided. But the mid block only needs one scale value, because it is a single block with a single layer.

I still don't see why we shouldn't make it configurable as well. Even if it's just a single block, there might be effects if users can control its contribution. So I would vote for the ability to configure it too.

> Would you like me to also implement scaling of the conv layers?
> Would you like me to also implement scaling of the resnets? Currently, only the transformers are scaled.

All of them should be configurable, IMO, if we were to support this feature; otherwise the design might feel very convoluted. But I would also like to take opinions from @BenjaminBossan and @yiyixuxu here.

@UmerHA (Contributor, Author) commented Mar 18, 2024

> I still don't see why we shouldn't make it configurable as well. Even if it's just a single block, there might be effects if users can control its contribution. So I would vote for the ability to configure it too.

I think we're miscommunicating. With the current PR, you CAN already configure the mid block:

pipe = ...  # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {"unet": {"mid": 0.5}}
pipe.set_adapters("my_adapter", scales)

However, you can only pass a single number for it, because it contains only one transformer layer in total. The up/down parts, in contrast, have multiple blocks and multiple layers per block; that's why you can pass a Dict (per block) or Dict[List] (per block, per layer) for them.
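
Put differently, the accepted shapes look roughly like this. The type aliases below are an illustrative sketch based on the descriptions in this thread, not the PR's actual annotations:

from typing import Dict, List, Union

ScalePerPart = Union[
    float,                    # e.g. "mid": a single float, since it has one transformer layer
    Dict[str, float],         # e.g. "up"/"down": one float per block ("block_0", "block_1", ...)
    Dict[str, List[float]],   # e.g. "up"/"down": one float per transformer within each block
]
UnetScales = Dict[str, ScalePerPart]                   # keys: "down", "mid", "up"
PipelineScales = Dict[str, Union[float, UnetScales]]   # keys: "text_encoder", "text_encoder_2", "unet"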

@sayakpaul (Member) commented

> However, you can only pass a single number for it, because it contains only one transformer layer in total. The up/down parts, in contrast, have multiple blocks and multiple layers per block; that's why you can pass a Dict (per block) or Dict[List] (per block, per layer) for them.

Ah makes sense then.

@BenjaminBossan (Member) left a comment

Thanks for adding this feature. I don't have sufficient knowledge to judge if this covers all edge cases, so I'll leave that to the experts. Still added a few comments.

Review threads (resolved): src/diffusers/loaders/lora.py (7), src/diffusers/loaders/unet.py (3)
@stevhliu (Member) left a comment

Awesome, great work! 👏

Review threads (resolved): docs/source/en/using-diffusers/loading_adapters.md (2)
@yiyixuxu (Collaborator) left a comment

Hey! Thanks for your PR.
I did a first round of review - let me know if it makes sense.

Thanks
YiYi

Review threads (resolved): src/diffusers/loaders/lora.py (5), src/diffusers/loaders/unet.py (1)
@sayakpaul (Member) commented

@UmerHA it would be helpful to resolve the comments that you have already addressed. But do feel free to keep them open if you're unsure.

Will give this a thorough look tomorrow. Making a reminder.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

LGTM overall, nice tests. I just have some smaller comments, please check.

All in all, I don't know enough about diffusers to give a final approval, so I'll leave that to the experts.

Review threads (resolved): docs/source/en/using-diffusers/loading_adapters.md (1), src/diffusers/loaders/lora.py (2), src/diffusers/loaders/unet.py (1)
@UmerHA (Contributor, Author) commented Mar 24, 2024

> would something like this work?

With a small adjustment, yes. `invert_list_adapters` needs to be created a bit differently, because `self.get_list_adapters()` returns `Dict[str, List[str]]`, not `Dict[str, str]`.

This works:

list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
all_adapters = {
    adapter for adapters in list_adapters.values() for adapter in adapters
}  # eg {"adapter1", "adapter2"} (a set)
invert_list_adapters = {
    adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
    for adapter in all_adapters
}  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

I have adjusted the code accordingly, and renamed lora_loaders_util.py back to unet_loaders_util.py.
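
For context on where such an inverted mapping is useful, here is a hedged sketch of the unknown-key warning path. The helper warn_on_unknown_parts and its signature are hypothetical, while the warning text mirrors the one quoted later in this thread:

import logging

logger = logging.getLogger(__name__)

# Illustrative only: warn when the scales dict references a part the adapter has no weights for.
def warn_on_unknown_parts(adapter_name, scales, invert_list_adapters):
    valid_parts = invert_list_adapters.get(adapter_name, [])
    for part in scales:
        if part not in valid_parts:
            logger.warning(
                f"Lora weight dict for adapter '{adapter_name}' contains {part}, but this will be "
                f"ignored because {adapter_name} does not contain weights for {part}. "
                f"Valid parts for {adapter_name} are: {valid_parts}."
            )

# eg warn_on_unknown_parts("lora", {"uent": 0.5}, {"lora": ["text_encoder", "text_encoder_2", "unet"]})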

@UmerHA (Contributor, Author) commented Mar 28, 2024

@sayakpaul gentle ping, as this has been stale for a week

@sayakpaul (Member) commented

Will do a thorough review tomorrow first thing. Thanks for your patience.

@sayakpaul (Member) left a comment

Left a last couple of comments. Looking really nice. Going to run some slow tests to ensure we're not breaking anything here.

Really amazing work!

@sayakpaul (Member) commented

LoRA SD slow tests are passing, which should be more than enough to confirm the PR isn't backwards-breaking. Let's address the last comments and ship this beast :shipit:

@UmerHA (Contributor, Author) commented Mar 29, 2024

@sayakpaul Everything is addressed! IIUC, you now have to accept the PR for the uploaded images (https://huggingface.co/datasets/huggingface/documentation-images/discussions/310) and then we're done :shipit::shipit::shipit:

@sayakpaul (Member) commented

Already merged.

@asomoza (Member) commented Mar 29, 2024

Really nice. When I have the time I'll play a lot with this and maybe post my results, so people can better understand what they can do with this level of control over the LoRAs.

@sayakpaul (Member) commented

@UmerHA could you push an empty commit to start the CI? And please ping me once you do.

@UmerHA (Contributor, Author) commented Mar 29, 2024

@sayakpaul empty commit pushed

@sayakpaul (Member) commented

Thanks @UmerHA!

Will wait for the CI to run and will then ship.

To help promote this amazing feature, I welcome you to open a detailed thread here: https://github.com/huggingface/diffusers/discussions.

Additionally, if you're on HF Discord, please let me know your username so that we can give you a shoutout. If you're not, you can join via this link: https://hf.co/join.discord.

@UmerHA (Contributor, Author) commented Mar 29, 2024

> Thanks @UmerHA!

❤️

> To help promote this amazing feature, I welcome you to open a detailed thread here: https://github.com/huggingface/diffusers/discussions.

Will do this evening/tomorrow ✅

> Additionally, if you're on HF Discord, please let me know your username so that we can give you a shoutout. If you're not, you can join via this link: https://hf.co/join.discord.

Thanks :) My Discord is umerha. But my primary channel is Twitter, where I am @UmerHAdil. If possible, I'd highly appreciate a shoutout over there 🙏🏽

@sayakpaul merged commit 0302446 into huggingface:main Mar 29, 2024
15 checks passed
@UmerHA deleted the 7231-blockwise-lora branch March 29, 2024 18:29
@asomoza (Member) commented Mar 29, 2024

@UmerHA I'll leave this here because I can't do a PR right now for a quick fix.

Just updated my main branch and ran the test I did before; there's a small typo, uent, that throws a warning:

Lora weight dict for adapter 'lora' contains uent, but this will be ignored because lora does not contain weights for uent. Valid parts for lora are: ['text_encoder', 'text_encoder_2', 'unet'].

@UmerHA (Contributor, Author) commented Mar 29, 2024

> there's a small typo, uent, that throws a warning:

@asomoza thanks! I've made a fix PR here.

sayakpaul pushed a commit that referenced this pull request Mar 30, 2024
@UmerHA (Contributor, Author) commented Mar 30, 2024

> To help promote this amazing feature, I welcome you to open a detailed thread here: https://github.com/huggingface/diffusers/discussions.
>
> Will do this evening/tomorrow ✅

Done ✅

noskill pushed a commit to noskill/diffusers that referenced this pull request Apr 5, 2024
* Initial commit

* Implemented block lora

- implemented block lora
- updated docs
- added tests

* Finishing up

* Reverted unrelated changes made by make style

* Fixed typo

* Fixed bug + Made text_encoder_2 scalable

* Integrated some review feedback

* Incorporated review feedback

* Fix tests

* Made every module configurable

* Adapted to new lora test structure

* Final cleanup

* Some more final fixes

- Included examples in `using_peft_for_inference.md`
- Added hint that only attns are scaled
- Removed NoneTypes
- Added test to check mismatching lens of adapter names / weights raise error

* Update using_peft_for_inference.md

* Update using_peft_for_inference.md

* Make style, quality, fix-copies

* Updated tutorial;Warning if scale/adapter mismatch

* floats are forwarded as-is; changed tutorial scale

* make style, quality, fix-copies

* Fixed typo in tutorial

* Moved some warnings into `lora_loader_utils.py`

* Moved scale/lora mismatch warnings back

* Integrated final review suggestions

* Empty commit to trigger CI

* Reverted empty commit to trigger CI

---------

Co-authored-by: Sayak Paul <[email protected]>
noskill pushed a commit to noskill/diffusers that referenced this pull request Apr 5, 2024
@asomoza mentioned this pull request Apr 5, 2024
XSE42 added a commit to XSE42/diffusers3d that referenced this pull request Apr 30, 2024
diffusers commit 0302446
    Implements Blockwise lora (huggingface/diffusers#7352)

Successfully merging this pull request may close these issues:

Applying Blockwise Weights to LoRA (#7231)

7 participants