Skip to content

Commit

Permalink
Implements Blockwise lora (huggingface#7352)
Browse files Browse the repository at this point in the history
* Initial commit

* Implemented block lora

- implemented block lora
- updated docs
- added tests

* Finishing up

* Reverted unrelated changes made by make style

* Fixed typo

* Fixed bug + Made text_encoder_2 scalable

* Integrated some review feedback

* Incorporated review feedback

* Fix tests

* Made every module configurable

* Adapter to new lora test structure

* Final cleanup

* Some more final fixes

- Included examples in `using_peft_for_inference.md`
- Added hint that only attns are scaled
- Removed NoneTypes
- Added test to check mismatching lens of adapter names / weights raise error

* Update using_peft_for_inference.md

* Update using_peft_for_inference.md

* Make style, quality, fix-copies

* Updated tutorial;Warning if scale/adapter mismatch

* floats are forwarded as-is; changed tutorial scale

* make style, quality, fix-copies

* Fixed typo in tutorial

* Moved some warnings into `lora_loader_utils.py`

* Moved scale/lora mismatch warnings back

* Integrated final review suggestions

* Empty commit to trigger CI

* Reverted emoty commit to trigger CI

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
2 people authored and noskill committed Apr 5, 2024
1 parent 14f6d2d commit 0a14b9d
Show file tree
Hide file tree
Showing 7 changed files with 553 additions and 21 deletions.
56 changes: 56 additions & 0 deletions docs/source/en/tutorials/using_peft_for_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,62 @@ image

![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)

### Customize adapters strength
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].

For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
```python
pipe.enable_lora() # enable lora again, after we disabled it above
prompt = "toy_face of a hacker with a hoodie, pixel art"
adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-text-and-down](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_down.png)

Let's see how turning off the `down` part and turning on the `mid` and `up` part respectively changes the image.
```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-text-and-mid](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mid.png)

```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-text-and-up](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_up.png)

Looks cool!

This is a really powerful feature. You can use it to control the adapter strengths down to per-transformer level. And you can even use it for multiple adapters.
```python
adapter_weight_scales_toy = 0.5
adapter_weight_scales_pixel = {
"unet": {
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
}
}
}
pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)

## Manage active adapters

You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
Expand Down
37 changes: 31 additions & 6 deletions docs/source/en/using-diffusers/loading_adapters.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,43 @@ image
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>

<Tip>

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.

</Tip>

To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:

```py
pipeline.unload_lora_weights()
```

### Adjust LoRA weight scale

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.

For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying by how much to scale the weights in each layer by.
```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
"text_encoder": 0.5,
"text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
"unet": {
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
}
}
}
pipe.set_adapters("my_adapter", scales)
```

This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.

<Tip warning={true}>

Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.

</Tip>

### Kohya and TheLastBen

Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
Expand Down
84 changes: 75 additions & 9 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
from pathlib import Path
Expand Down Expand Up @@ -985,7 +986,7 @@ def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: List[float] = None,
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
"""
Sets the adapter layers for the text encoder.
Expand All @@ -1003,15 +1004,20 @@ def set_adapters_for_text_encoder(
raise ValueError("PEFT backend is required for this method.")

def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]
# Expand weights into a list, one entry per adapter
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)

# Set None values to default of 1.0
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
weights = [w if w is not None else 1.0 for w in weights]

return weights

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
Expand Down Expand Up @@ -1059,17 +1065,77 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

adapter_weights = copy.deepcopy(adapter_weights)

# Expand weights into a list, one entry per adapter
if not isinstance(adapter_weights, list):
adapter_weights = [adapter_weights] * len(adapter_names)

if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)

# Decompose weights into weights for unet, text_encoder and text_encoder_2
unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []

list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
all_adapters = {
adapter for adapters in list_adapters.values() for adapter in adapters
} # eg ["adapter1", "adapter2"]
invert_list_adapters = {
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
for adapter in all_adapters
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
unet_lora_weight = weights.pop("unet", None)
text_encoder_lora_weight = weights.pop("text_encoder", None)
text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)

if len(weights) > 0:
raise ValueError(
f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
)

if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
logger.warning(
"Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
)

# warn if adapter doesn't have parts specified by adapter_weights
for part_weight, part_name in zip(
[unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
["uent", "text_encoder", "text_encoder_2"],
):
if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
logger.warning(
f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
)

else:
unet_lora_weight = weights
text_encoder_lora_weight = weights
text_encoder_2_lora_weight = weights

unet_lora_weights.append(unet_lora_weight)
text_encoder_lora_weights.append(text_encoder_lora_weight)
text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)

unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
unet.set_adapters(adapter_names, adapter_weights)
unet.set_adapters(adapter_names, unet_lora_weights)

# Handle the Text Encoder
if hasattr(self, "text_encoder"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
if hasattr(self, "text_encoder_2"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)

def disable_lora(self):
if not USE_PEFT_BACKEND:
Expand Down
16 changes: 12 additions & 4 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers


Expand Down Expand Up @@ -564,7 +565,7 @@ def _unfuse_lora_apply(self, module):
def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[List[float], float]] = None,
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
Expand Down Expand Up @@ -597,16 +598,23 @@ def set_adapters(

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
# Expand weights into a list, one entry per adapter
# examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)

# Set None values to default of 1.0
# e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
weights = [w if w is not None else 1.0 for w in weights]

# e.g. [{...}, 7] -> [{expanded dict...}, 7]
weights = _maybe_expand_lora_scales(self, weights)

set_weights_and_activate_adapters(self, adapter_names, weights)

def disable_lora(self):
Expand Down
Loading

0 comments on commit 0a14b9d

Please sign in to comment.