
Implements Blockwise lora #7352

Merged: 40 commits, Mar 29, 2024

Changes from 11 commits

Commits (40)
8404d7b
Initial commit
UmerHA Mar 15, 2024
84125df
Implemented block lora
UmerHA Mar 16, 2024
7405aff
Finishing up
UmerHA Mar 16, 2024
5c19f18
Reverted unrelated changes made by make style
UmerHA Mar 16, 2024
769f42b
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 16, 2024
8908c90
Fixed typo
UmerHA Mar 16, 2024
d9c55a5
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 16, 2024
7e6ce83
Fixed bug + Made text_encoder_2 scalable
UmerHA Mar 16, 2024
3c841fc
Integrated some review feedback
UmerHA Mar 18, 2024
72b8752
Incorporated review feedback
UmerHA Mar 19, 2024
2247bcb
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 19, 2024
145c7f3
Fix tests
UmerHA Mar 19, 2024
8e26004
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 19, 2024
87e54b4
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 20, 2024
83ff34b
Made every module configurable
UmerHA Mar 20, 2024
5054f02
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 20, 2024
c2395fa
Adapter to new lora test structure
UmerHA Mar 21, 2024
624b2dd
Final cleanup
UmerHA Mar 21, 2024
578e974
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
0b32d64
Some more final fixes
UmerHA Mar 21, 2024
2b4aae6
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 21, 2024
38038b7
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
7411cab
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
3ed3ca5
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
df9df2e
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 22, 2024
24d376f
Make style, quality, fix-copies
UmerHA Mar 22, 2024
8fa6c25
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 22, 2024
7dfa8e3
Updated tutorial;Warning if scale/adapter mismatch
UmerHA Mar 22, 2024
9c6f613
floats are forwarded as-is; changed tutorial scale
UmerHA Mar 23, 2024
a469a4d
make style, quality, fix-copies
UmerHA Mar 23, 2024
957358b
Fixed typo in tutorial
UmerHA Mar 23, 2024
cb062b6
Moved some warnings into `lora_loader_utils.py`
UmerHA Mar 23, 2024
1e61dfb
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 23, 2024
a4a38df
Moved scale/lora mismatch warnings back
UmerHA Mar 24, 2024
9aa1479
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 27, 2024
625045a
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 28, 2024
2939e45
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 29, 2024
14fabf0
Integrated final review suggestions
UmerHA Mar 29, 2024
8500161
Empty commit to trigger CI
UmerHA Mar 29, 2024
74ce9bb
Reverted emoty commit to trigger CI
UmerHA Mar 29, 2024
29 changes: 23 additions & 6 deletions docs/source/en/using-diffusers/loading_adapters.md
@@ -153,18 +153,35 @@ image
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>

<Tip>

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.

</Tip>

To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:

```py
pipeline.unload_lora_weights()
```

### Adjust LoRA weight scale

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
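
For example, a minimal sketch of running inference with the LoRA at half strength (the prompt is illustrative):

```py
# assumes a LoRA has already been loaded into `pipeline`
image = pipeline(
    "a photo of a cute dog",  # illustrative prompt
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```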

For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying how much to scale the weights in each layer by.
```python
pipe = ...  # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
    "text_encoder": 0.5,
    "text_encoder_2": 0.5,  # only usable if pipe has a 2nd text encoder
    "unet": {
        "down": 0.9,  # all transformers in the down-part will use scale 0.9
        # "mid" is not given here, so all transformers in the mid-part will use the default scale 1.0
        "up": {
            "block_0": 0.6,  # all 3 transformers in the 0th block in the up-part will use scale 0.6
            "block_1": [0.4, 0.8, 1.0],  # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
        },
    },
}
pipe.set_adapters("my_adapter", scales)
```
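
Because `set_adapters` also accepts a list of adapter names with one weight entry per adapter, the same blockwise control works when combining multiple LoRAs. A minimal sketch (the adapter names are illustrative):

```python
pipe.load_lora_weights(..., adapter_name="adapter_a")
pipe.load_lora_weights(..., adapter_name="adapter_b")
pipe.set_adapters(
    ["adapter_a", "adapter_b"],
    [
        {"unet": {"down": 0.8}},  # adapter_a: scale only the down-part; unscaled parts default to 1.0
        0.6,                      # adapter_b: a plain float scales every component uniformly
    ],
)
```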

### Kohya and TheLastBen

Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
56 changes: 47 additions & 9 deletions src/diffusers/loaders/lora.py
@@ -14,6 +14,7 @@
import inspect
import os
from pathlib import Path
from types import NoneType
from typing import Callable, Dict, List, Optional, Union

import safetensors
@@ -959,7 +960,7 @@ def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: List[float] = None,
text_encoder_weights: Optional[Union[float, List[float], List[NoneType]]] = None,
):
"""
Sets the adapter layers for the text encoder.
@@ -977,15 +978,17 @@ def set_adapters_for_text_encoder(
raise ValueError("PEFT backend is required for this method.")

def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)

weights = [w or 1.0 for w in weights] # Set None values to default of 1.0
weights = [{"text_model": w} for w in weights]

return weights

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1033,17 +1036,52 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

# Expand weights into a list, one entry per adapter
if not isinstance(adapter_weights, list):
adapter_weights = [adapter_weights] * len(adapter_names)

if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)

# Decompose weights into weights for unet, text_encoder and text_encoder_2
unet_weights, text_encoder_weights, text_encoder_2_weights = [], [], []

for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
unet_weight = weights.pop("unet", None)
text_encoder_weight = weights.pop("text_encoder", None)
text_encoder_2_weight = weights.pop("text_encoder_2", None)

if len(weights) > 0:
raise ValueError(f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}.")

if text_encoder_2_weight is not None and not hasattr(self, "text_encoder_2"):
logger.warning("Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2.")
else:
unet_weight = weights
text_encoder_weight = weights
text_encoder_2_weight = weights

unet_weights.append(unet_weight)
text_encoder_weights.append(text_encoder_weight)
text_encoder_2_weights.append(text_encoder_2_weight)

unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
unet.set_adapters(adapter_names, adapter_weights)
unet.set_adapters(adapter_names, unet_weights)

# Handle the Text Encoder
if hasattr(self, "text_encoder"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_weights)
if hasattr(self, "text_encoder_2"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_weights)

def disable_lora(self):
if not USE_PEFT_BACKEND:
122 changes: 118 additions & 4 deletions src/diffusers/loaders/unet.py
@@ -17,6 +17,7 @@
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from types import NoneType
from typing import Callable, Dict, List, Optional, Union

import safetensors
@@ -561,10 +562,112 @@ def _unfuse_lora_apply(self, module):
if isinstance(module, BaseTunerLayer):
module.unmerge()

def _expand_lora_scales_dict(
self, scales: Union[float, Dict], blocks_with_transformer: Dict[str, List[int]], transformer_per_block: Dict[str, int]
):
"""
Expands the inputs into a more granular dictionary. See the example below for more details.

Parameters:
scales (`Union[float, Dict]`):
Scales dict to expand.
blocks_with_transformer (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has

E.g. turns
scales = {
'down': 2,
'mid': 3,
'up': {
'block_0': 4,
'block_1': [5, 6, 7]
}
}
blocks_with_transformer = {
'down': [1,2],
'up': [0,1]
}
transformer_per_block = {
'down': 2,
'up': 3
}
into
{
'down.block_1.0': 2,
'down.block_1.1': 2,
'down.block_2.0': 2,
'down.block_2.1': 2,
'mid': 3,
'up.block_0.0': 4,
'up.block_0.1': 4,
'up.block_0.2': 4,
'up.block_1.0': 5,
'up.block_1.1': 6,
'up.block_1.2': 7,
}
"""
if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

if sorted(transformer_per_block.keys()) != ["down", "up"]:
raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

if not isinstance(scales, dict):
scales = {o: scales for o in ["down", "mid", "up"]}

if "mid" not in scales:
scales["mid"] = 1

for updown in ["up", "down"]:
if updown not in scales:
scales[updown] = 1

# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if not isinstance(scales[updown], dict):
scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}

# eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
if not isinstance(scales[updown][block], dict):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]

# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
for tf_idx, value in enumerate(scales[updown][block]):
scales[f"{updown}.{block}.{tf_idx}"] = value

del scales[updown]

def layer_name(name):
"""Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
if name == "mid":
return "mid_block.attentions.0"

updown, block, attn = name.split(".")

updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
block = block.replace("block_", "")
attn = "attentions." + attn

return ".".join((updown, block, attn))

state_dict = self.state_dict()
for layer in scales.keys():
if not any(layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
)

return {layer_name(name): weight for name, weight in scales.items()}

def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[List[float], float]] = None,
weights: Optional[Union[float, Dict, List[float], List[Dict], List[NoneType]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
@@ -597,16 +700,27 @@ def set_adapters(

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
# Expand weights into a list, one entry per adapter
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)

weights = [weight or 1.0 for weight in weights] # Set None values to default of 1.0
blocks_with_transformer = {
"down": [i for i, block in enumerate(self.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(self.up_blocks) if hasattr(block, "attentions")],
}
transformer_per_block = {"down": self.config.layers_per_block, "up": self.config.layers_per_block + 1}

weights = [
self._expand_lora_scales_dict(weight_for_adapter, blocks_with_transformer, transformer_per_block)
for weight_for_adapter in weights
]

set_weights_and_activate_adapters(self, adapter_names, weights)

def disable_lora(self):
10 changes: 8 additions & 2 deletions src/diffusers/utils/peft_utils.py
@@ -227,16 +227,22 @@ def delete_adapter_layers(model, adapter_name):
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer

def get_module_weight(weight_for_adapter, module_name):
for layer_name, weight_ in weight_for_adapter.items():
if layer_name in module_name:
return weight_
raise RuntimeError(f"No LoRA weight found for module {module_name}, which should never happen.")

# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
for module_name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.set_scale(adapter_name, weight)
module.set_scale(adapter_name, get_module_weight(weight, module_name))

# set multiple active adapters
for module in model.modules():
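
For reference, a minimal sketch of how `get_module_weight` matches the expanded scale dict against module names (the module name below is a hypothetical example, not taken from this PR):

```python
# After _expand_lora_scales_dict, each adapter's UNet weights are keyed by layer name, e.g.:
weight_for_adapter = {"mid_block.attentions.0": 3.0, "up_blocks.1.attentions.2": 7.0}

# get_module_weight returns the scale whose layer name is a substring of the module's
# fully qualified name, so this (hypothetical) module would be scaled by 7.0:
module_name = "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_q"
scale = next(w for layer, w in weight_for_adapter.items() if layer in module_name)
```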