Implements Blockwise lora #7352

Merged
merged 40 commits on Mar 29, 2024
Changes from 8 commits
Commits
8404d7b
Initial commit
UmerHA Mar 15, 2024
84125df
Implemented block lora
UmerHA Mar 16, 2024
7405aff
Finishing up
UmerHA Mar 16, 2024
5c19f18
Reverted unrelated changes made by make style
UmerHA Mar 16, 2024
769f42b
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 16, 2024
8908c90
Fixed typo
UmerHA Mar 16, 2024
d9c55a5
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 16, 2024
7e6ce83
Fixed bug + Made text_encoder_2 scalable
UmerHA Mar 16, 2024
3c841fc
Integrated some review feedback
UmerHA Mar 18, 2024
72b8752
Incorporated review feedback
UmerHA Mar 19, 2024
2247bcb
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 19, 2024
145c7f3
Fix tests
UmerHA Mar 19, 2024
8e26004
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 19, 2024
87e54b4
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 20, 2024
83ff34b
Made every module configurable
UmerHA Mar 20, 2024
5054f02
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 20, 2024
c2395fa
Adapted to new lora test structure
UmerHA Mar 21, 2024
624b2dd
Final cleanup
UmerHA Mar 21, 2024
578e974
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
0b32d64
Some more final fixes
UmerHA Mar 21, 2024
2b4aae6
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 21, 2024
38038b7
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
7411cab
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
3ed3ca5
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
df9df2e
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 22, 2024
24d376f
Make style, quality, fix-copies
UmerHA Mar 22, 2024
8fa6c25
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 22, 2024
7dfa8e3
Updated tutorial; warning if scale/adapter mismatch
UmerHA Mar 22, 2024
9c6f613
floats are forwarded as-is; changed tutorial scale
UmerHA Mar 23, 2024
a469a4d
make style, quality, fix-copies
UmerHA Mar 23, 2024
957358b
Fixed typo in tutorial
UmerHA Mar 23, 2024
cb062b6
Moved some warnings into `lora_loader_utils.py`
UmerHA Mar 23, 2024
1e61dfb
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 23, 2024
a4a38df
Moved scale/lora mismatch warnings back
UmerHA Mar 24, 2024
9aa1479
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 27, 2024
625045a
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 28, 2024
2939e45
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 29, 2024
14fabf0
Integrated final review suggestions
UmerHA Mar 29, 2024
8500161
Empty commit to trigger CI
UmerHA Mar 29, 2024
74ce9bb
Reverted empty commit to trigger CI
UmerHA Mar 29, 2024
16 changes: 16 additions & 0 deletions docs/source/en/using-diffusers/loading_adapters.md
@@ -157,6 +157,22 @@ image

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
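As a quick sketch (pipeline creation, checkpoint, and prompt are placeholders here):

```python
pipe = ...  # create pipeline
pipe.load_lora_weights(...)  # load a LoRA checkpoint

# use the LoRA at half strength for this generation only
image = pipe("a prompt", cross_attention_kwargs={"scale": 0.5}).images[0]
```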

For more fine-grained control over how much of the LoRA weights are used, use [`~loaders.LoraLoaderMixin.set_adapters`]. It accepts scales at any granularity, down to individual transformer layers within a block.
```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
"text_encoder": 0.5,
"text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # because "mid" is not given, all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
}
}
pipe.load_lora_weights("my_adapter", scales)
```
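You can also pass one such dict (or a plain float) per adapter when several LoRAs are loaded; anything left unspecified falls back to the default scale of 1.0. A minimal sketch, assuming a second adapter named `my_other_adapter` has been loaded:

```python
pipe.load_lora_weights(..., adapter_name="my_other_adapter")
pipe.set_adapters(["my_adapter", "my_other_adapter"], [scales, 0.7])
```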
</Tip>

To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
69 changes: 64 additions & 5 deletions src/diffusers/loaders/lora.py
@@ -959,7 +959,7 @@ def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: List[float] = None,
text_encoder_weights: Optional[Union[float, List[float]]] = None,
):
"""
Sets the adapter layers for the text encoder.
@@ -986,6 +986,9 @@ def process_weights(adapter_names, weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)

weights = [{"text_model": w} if w is not None else {"text_model": 1.0} for w in weights]

return weights

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1033,17 +1036,73 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
adapter_weights: Optional[Union[List[float], List[Dict]]] = None,
):
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

number = (float, int)
has_2nd_text_encoder = hasattr(self, "text_encoder_2")

# Expand weights into a list, one entry per adapter
if adapter_weights is None or isinstance(adapter_weights, (number, dict)):
adapter_weights = [adapter_weights] * len(adapter_names)

if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)

# Normalize into dicts
allowed_keys = ["text_encoder", "down", "mid", "up"]
if has_2nd_text_encoder:
allowed_keys.append("text_encoder_2")

def ensure_is_dict(weight):
if isinstance(weight, dict):
return weight
elif isinstance(weight, number):
return {key: weight for key in allowed_keys}
elif weight is None:
return {key: 1.0 for key in allowed_keys}
else:
raise RuntimeError(f"lora weight has wrong type {type(weight)}.")

adapter_weights = [ensure_is_dict(weight) for weight in adapter_weights]

for weights in adapter_weights:
for k in weights.keys():
if k not in allowed_keys:
raise ValueError(
f"Got invalid key '{k}' in lora weight dict. Allowed keys are 'text_encoder', 'text_encoder_2', 'down', 'mid', 'up'."
)

# Decompose weights into weights for unet, text_encoder and text_encoder_2
unet_weights, text_encoder_weights = [], []
if has_2nd_text_encoder:
text_encoder_2_weights = []

for weights in adapter_weights:
unet_weight = {k: v for k, v in weights.items() if "text_encoder" not in k}
if len(unet_weight) == 0:
unet_weight = None
text_encoder_weight = weights.get("text_encoder", None)
if has_2nd_text_encoder:
text_encoder_2_weight = weights.get("text_encoder_2", None)

unet_weights.append(unet_weight)
text_encoder_weights.append(text_encoder_weight)
if has_2nd_text_encoder:
text_encoder_2_weights.append(text_encoder_2_weight)

unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
unet.set_adapters(adapter_names, adapter_weights)
unet.set_adapters(adapter_names, unet_weights)

# Handle the Text Encoder
if hasattr(self, "text_encoder"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_weights)
if hasattr(self, "text_encoder_2"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_weights)

def disable_lora(self):
if not USE_PEFT_BACKEND:
109 changes: 108 additions & 1 deletion src/diffusers/loaders/unet.py
@@ -561,6 +561,100 @@ def _unfuse_lora_apply(self, module):
if isinstance(module, BaseTunerLayer):
module.unmerge()

def _expand_lora_scales_dict(
self, scales, blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int]
):
"""
Expands the `scales` input into a dict with one scale per transformer layer

Parameters:
blocks_with_transformer (`Dict[str, List[int]]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has

E.g. turns
{
'down': 2,
'mid': 3,
'up': {
'block_0': 4,
'block_1': [5, 6, 7]
}
}
into
{
'down.block_1.0': 2,
'down.block_1.1': 2,
'down.block_2.0': 2,
'down.block_2.1': 2,
'mid': 3,
'up.block_0.0': 4,
'up.block_0.1': 4,
'up.block_0.2': 4,
'up.block_1.0': 5,
'up.block_1.1': 6,
'up.block_1.2': 7,
}
"""
number = (float, int)

if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

if sorted(transformer_per_block.keys()) != ["down", "up"]:
raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

if isinstance(scales, number):
scales = {o: scales for o in ["down", "mid", "up"]}

if "mid" not in scales:
scales["mid"] = 1

for updown in ["up", "down"]:
if updown not in scales:
scales[updown] = 1

# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if isinstance(scales[updown], number):
scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}

# eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
if isinstance(scales[updown][block], number):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]

# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
for tf_idx, value in enumerate(scales[updown][block]):
scales[f"{updown}.{block}.{tf_idx}"] = value

del scales[updown]

def layer_name(name):
"""Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
if name == "mid":
return "mid_block.attentions.0"

updown, block, attn = name.split(".")

updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
block = block.replace("block_", "")
attn = "attentions." + attn

return ".".join((updown, block, attn))

state_dict = self.state_dict()
for layer in scales.keys():
if not any(layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or has not attentions."
)

return {layer_name(name): weight for name, weight in scales.items()}

def set_adapters(
self,
adapter_names: Union[List[str], str],
@@ -599,14 +693,27 @@ def set_adapters(

if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
elif isinstance(weights, (float, dict)):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)

# Set missing value to default of 1.0
weights = [weight if weight is not None else 1.0 for weight in weights]
blocks_with_transformer = {
"down": [i for i, block in enumerate(self.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(self.up_blocks) if hasattr(block, "attentions")],
}
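# Up blocks carry one more transformer layer per block than down blocks, hence layers_per_block + 1 below.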
transformer_per_block = {"down": self.config.layers_per_block, "up": self.config.layers_per_block + 1}

weights = [
self._expand_lora_scales_dict(weight_for_adapter, blocks_with_transformer, transformer_per_block)
for weight_for_adapter in weights
]

set_weights_and_activate_adapters(self, adapter_names, weights)

def disable_lora(self):
10 changes: 8 additions & 2 deletions src/diffusers/utils/peft_utils.py
@@ -227,16 +227,22 @@ def delete_adapter_layers(model, adapter_name):
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer

def get_module_weight(weight_for_adapter, module_name):
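# weight_for_adapter maps expanded layer-name prefixes (e.g. "mid_block.attentions.0") to scales;
# match by substring against the full PEFT module name.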
for layer_name, weight_ in weight_for_adapter.items():
if layer_name in module_name:
return weight_
raise RuntimeError(f"No LoRA weight found for module {module_name}, which should never happen.")

# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
for module_name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
# For backward compatibility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.set_scale(adapter_name, weight)
module.set_scale(adapter_name, get_module_weight(weight, module_name))

# set multiple active adapters
for module in model.modules():