Implements Blockwise lora #7352

Merged
merged 40 commits into main from UmerHA:7231-blockwise-lora
Mar 29, 2024
Changes from 36 commits (40 commits total)

Commits
8404d7b
Initial commit
UmerHA Mar 15, 2024
84125df
Implemented block lora
UmerHA Mar 16, 2024
7405aff
Finishing up
UmerHA Mar 16, 2024
5c19f18
Reverted unrelated changes made by make style
UmerHA Mar 16, 2024
769f42b
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 16, 2024
8908c90
Fixed typo
UmerHA Mar 16, 2024
d9c55a5
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 16, 2024
7e6ce83
Fixed bug + Made text_encoder_2 scalable
UmerHA Mar 16, 2024
3c841fc
Integrated some review feedback
UmerHA Mar 18, 2024
72b8752
Incorporated review feedback
UmerHA Mar 19, 2024
2247bcb
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 19, 2024
145c7f3
Fix tests
UmerHA Mar 19, 2024
8e26004
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 19, 2024
87e54b4
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 20, 2024
83ff34b
Made every module configurable
UmerHA Mar 20, 2024
5054f02
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 20, 2024
c2395fa
Adapter to new lora test structure
UmerHA Mar 21, 2024
624b2dd
Final cleanup
UmerHA Mar 21, 2024
578e974
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
0b32d64
Some more final fixes
UmerHA Mar 21, 2024
2b4aae6
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 21, 2024
38038b7
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
7411cab
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
3ed3ca5
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
df9df2e
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 22, 2024
24d376f
Make style, quality, fix-copies
UmerHA Mar 22, 2024
8fa6c25
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 22, 2024
7dfa8e3
Updated tutorial;Warning if scale/adapter mismatch
UmerHA Mar 22, 2024
9c6f613
floats are forwarded as-is; changed tutorial scale
UmerHA Mar 23, 2024
a469a4d
make style, quality, fix-copies
UmerHA Mar 23, 2024
957358b
Fixed typo in tutorial
UmerHA Mar 23, 2024
cb062b6
Moved some warnings into `lora_loader_utils.py`
UmerHA Mar 23, 2024
1e61dfb
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 23, 2024
a4a38df
Moved scale/lora mismatch warnings back
UmerHA Mar 24, 2024
9aa1479
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 27, 2024
625045a
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 28, 2024
2939e45
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 29, 2024
14fabf0
Integrated final review suggestions
UmerHA Mar 29, 2024
8500161
Empty commit to trigger CI
UmerHA Mar 29, 2024
74ce9bb
Reverted empty commit to trigger CI
UmerHA Mar 29, 2024
56 changes: 56 additions & 0 deletions docs/source/en/tutorials/using_peft_for_inference.md
@@ -133,6 +133,62 @@ image

![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)

### Customize adapter strengths
For even more customization, you can control how strongly the adapter affects each part of the pipeline. To do this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].

For example, here's how you can turn on the adapter for the `text_encoder` and the `down` parts of the UNet, but turn it off for the `mid` and `up` parts (any part not listed in the dictionary keeps the default scale of 1.0):
```python
pipe.enable_lora() # enable lora again, after we disabled it above
prompt = "toy_face of a hacker with a hoodie, pixel art"
adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-text-and-down](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_down.png)

Let's see how the image changes when we instead turn on only the `mid` part, and then only the `up` part.
```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-text-and-mid](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mid.png)

```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-text-and-up](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_up.png)

Looks cool!

This is a really powerful feature. You can use it to control the adapter strength down to the per-transformer level, and you can even use it with multiple adapters.
```python
adapter_weight_scales_toy = 0.5
adapter_weight_scales_pixel = {
    "unet": {
        "down": 0.9,  # all transformers in the down-part will use scale 0.9
        # "mid"  # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
        "up": {
            "block_0": 0.6,  # all 3 transformers in the 0th block in the up-part will use scale 0.6
            "block_1": [0.4, 0.8, 1.0],  # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
        }
    }
}
pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```

![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)

## Manage active adapters

You have attached multiple adapters in this tutorial, and if you're feeling a bit lost about which adapters are attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
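
As a minimal sketch (assuming the `"toy"` and `"pixel"` adapters from earlier are still loaded; the printed names are illustrative):

```python
active_adapters = pipe.get_active_adapters()
print(active_adapters)  # e.g. ["toy", "pixel"]
```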
45 changes: 39 additions & 6 deletions docs/source/en/using-diffusers/loading_adapters.md
@@ -153,18 +153,51 @@ image
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>

<Tip>

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.

</Tip>

To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:

```py
pipeline.unload_lora_weights()
```

### Adjust LoRA weight scale

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
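
For example, a minimal sketch of a pipeline call with a reduced LoRA scale (assuming `pipeline` already has LoRA weights loaded and `prompt` is defined):

```py
image = pipeline(prompt, cross_attention_kwargs={"scale": 0.5}).images[0]
```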

For more granular control over the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying how much to scale the weights in each layer.
```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
    "text_encoder": 0.5,
    "text_encoder_2": 0.5,  # only usable if pipe has a 2nd text encoder
    "unet": {
        "down": 0.9,  # all transformers in the down-part will use scale 0.9
        # "mid"  # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
        "up": {
            "block_0": 0.6,  # all 3 transformers in the 0th block in the up-part will use scale 0.6
            "block_1": [0.4, 0.8, 1.0],  # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
        }
    }
}
pipe.set_adapters("my_adapter", scales)
```

This also works with multiple adapters:
```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter_1")
pipe.load_lora_weights(..., adapter_name="my_adapter_2")
scales_1 = { ... }
scales_2 = { ... }
pipe.set_adapters(["my_adapter_1", "my_adapter_2"], [scales_1, scales_2])
```

<Tip warning={true}>

Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.

</Tip>

### Kohya and TheLastBen

Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
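
For example, a minimal sketch of loading such a checkpoint with [`~loaders.LoraLoaderMixin.load_lora_weights`] (the path and filename below are placeholders for an actual Kohya- or TheLastBen-style LoRA file):

```py
pipeline.load_lora_weights("path/to/lora", weight_name="kohya_style_lora.safetensors")
```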
84 changes: 75 additions & 9 deletions src/diffusers/loaders/lora.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
from pathlib import Path
@@ -985,7 +986,7 @@ def set_adapters_for_text_encoder(
        self,
        adapter_names: Union[List[str], str],
        text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
        text_encoder_weights: List[float] = None,
        text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
    ):
        """
        Sets the adapter layers for the text encoder.
@@ -1003,15 +1004,20 @@
            raise ValueError("PEFT backend is required for this method.")

        def process_weights(adapter_names, weights):
            if weights is None:
                weights = [1.0] * len(adapter_names)
            elif isinstance(weights, float):
                weights = [weights]
            # Expand weights into a list, one entry per adapter
            # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
            if not isinstance(weights, list):
                weights = [weights] * len(adapter_names)

            if len(adapter_names) != len(weights):
                raise ValueError(
                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
                )

            # Set None values to default of 1.0
            # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
            weights = [w if w is not None else 1.0 for w in weights]

            return weights

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1059,17 +1065,77 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        adapter_weights: Optional[List[float]] = None,
        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
    ):
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        adapter_weights = copy.deepcopy(adapter_weights)

        # Expand weights into a list, one entry per adapter
        if not isinstance(adapter_weights, list):
            adapter_weights = [adapter_weights] * len(adapter_names)

        if len(adapter_names) != len(adapter_weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
            )

        # Decompose weights into weights for unet, text_encoder and text_encoder_2
        unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []

        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
        all_adapters = {
            adapter for adapters in list_adapters.values() for adapter in adapters
        }  # eg ["adapter1", "adapter2"]
        invert_list_adapters = {
            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
            for adapter in all_adapters
        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

        for adapter_name, weights in zip(adapter_names, adapter_weights):
            if isinstance(weights, dict):
                unet_lora_weight = weights.pop("unet", None)
                text_encoder_lora_weight = weights.pop("text_encoder", None)
                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)

                if len(weights) > 0:
                    raise ValueError(
                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
                    )

                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
                    logger.warning(
                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
                    )

                # warn if adapter doesn't have parts specified by adapter_weights
                for part_weight, part_name in zip(
                    [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
["uent", "text_encoder", "text_encoder_2"],
                ):
                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
                        logger.warning(
                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
                        )

            else:
                unet_lora_weight = weights
                text_encoder_lora_weight = weights
                text_encoder_2_lora_weight = weights

            unet_lora_weights.append(unet_lora_weight)
            text_encoder_lora_weights.append(text_encoder_lora_weight)
            text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)

        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        # Handle the UNET
        unet.set_adapters(adapter_names, adapter_weights)
        unet.set_adapters(adapter_names, unet_lora_weights)

        # Handle the Text Encoder
        if hasattr(self, "text_encoder"):
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
        if hasattr(self, "text_encoder_2"):
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)

    def disable_lora(self):
        if not USE_PEFT_BACKEND:
16 changes: 12 additions & 4 deletions src/diffusers/loaders/unet.py
@@ -47,6 +47,7 @@
    infer_stable_cascade_single_file_config,
    load_single_file_model_checkpoint,
)
from .unet_loader_utils import maybe_expand_lora_scales
from .utils import AttnProcsLayers


@@ -564,7 +565,7 @@ def _unfuse_lora_apply(self, module):
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        weights: Optional[Union[List[float], float]] = None,
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
    ):
        """
        Set the currently active adapters for use in the UNet.
@@ -597,16 +598,23 @@ def set_adapters(

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        if weights is None:
            weights = [1.0] * len(adapter_names)
        elif isinstance(weights, float):
        # Expand weights into a list, one entry per adapter
        # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # Set None values to default of 1.0
        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
        weights = [w if w is not None else 1.0 for w in weights]

        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
        weights = maybe_expand_lora_scales(self, weights)

        set_weights_and_activate_adapters(self, adapter_names, weights)

    def disable_lora(self):