
Commit

Merge pull request #3461 from AI-Casanova/flux-lora
Native Flux nf4 LoRA
vladmandic authored Sep 30, 2024
2 parents 492ee38 + fce431b commit e6ff17d
Showing 6 changed files with 279 additions and 23 deletions.
200 changes: 200 additions & 0 deletions extensions-builtin/Lora/lora_convert.py
@@ -1,6 +1,7 @@
import os
import re
import bisect
import torch
from typing import Dict
from modules import shared

@@ -174,6 +175,8 @@ def diffusers(self, key):
            if search_key.startswith(map_key):
                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        if sd_module is None:
            sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None) # FLUX1 fix
        # SegMoE begin
        expert_key = key + "_experts_0"
        expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
@@ -253,3 +256,200 @@ def match(match_list, regex_text):
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
    return key


# Taken from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
# Modified from 'lora_A' and 'lora_B' to 'lora_down' and 'lora_up'
# Added early exit
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
def _convert_kohya_flux_lora_to_diffusers(state_dict):
    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
        alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
        scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        ait_sd[ait_key + ".lora_down.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_up.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        # scale weight by alpha and dim
        alpha = sds_sd.pop(sds_key + ".alpha")
        scale = alpha / sd_lora_rank

        # calculate scale_down and scale_up
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # check upweight is sparse or not
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
                    )
                i += dims[j]
            # if is_sparse:
            #     print(f"weight is sparse: {sds_key}")

        # make ai-toolkit weight
        ait_down_keys = [k + ".lora_down.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_up.weight" for k in ait_keys]
        if not is_sparse:
            # down_weight is copied to each split
            ait_sd.update({k: down_weight for k in ait_down_keys})

            # up_weight is split to each split
            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
        else:
            # down_weight is chunked to each split
            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split
            i = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
                i += dims[j]

    def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        ait_sd = {}
        for i in range(19):
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_out.0",
            )
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.to_q",
                    f"transformer.transformer_blocks.{i}.attn.to_k",
                    f"transformer.transformer_blocks.{i}.attn.to_v",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mlp_0",
                f"transformer.transformer_blocks.{i}.ff.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mlp_2",
                f"transformer.transformer_blocks.{i}.ff.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_img_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1.linear",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_attn_proj",
                f"transformer.transformer_blocks.{i}.attn.to_add_out",
            )
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_attn_qkv",
                [
                    f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                    f"transformer.transformer_blocks.{i}.attn.add_v_proj",
                ],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mlp_0",
                f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mlp_2",
                f"transformer.transformer_blocks.{i}.ff_context.net.2",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_double_blocks_{i}_txt_mod_lin",
                f"transformer.transformer_blocks.{i}.norm1_context.linear",
            )

        for i in range(38):
            _convert_to_ai_toolkit_cat(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear1",
                [
                    f"transformer.single_transformer_blocks.{i}.attn.to_q",
                    f"transformer.single_transformer_blocks.{i}.attn.to_k",
                    f"transformer.single_transformer_blocks.{i}.attn.to_v",
                    f"transformer.single_transformer_blocks.{i}.proj_mlp",
                ],
                dims=[3072, 3072, 3072, 12288],
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_linear2",
                f"transformer.single_transformer_blocks.{i}.proj_out",
            )
            _convert_to_ai_toolkit(
                sds_sd,
                ait_sd,
                f"lora_unet_single_blocks_{i}_modulation_lin",
                f"transformer.single_transformer_blocks.{i}.norm.linear",
            )

        if len(sds_sd) > 0:
            return None

        return ait_sd

    return _convert_sd_scripts_to_ai_toolkit(state_dict)
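For context, here is a minimal usage sketch of the converter added above. The rank, alpha, and tensor shapes are illustrative stand-ins (3072 matches the FLUX hidden size used in the dims list above), not values taken from this commit:

import torch

# Illustrative kohya/sd-scripts style LoRA entry for one double-block projection.
rank, alpha = 4, 4.0
sds_sd = {
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight": torch.randn(rank, 3072),
    "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight": torch.randn(3072, rank),
    "lora_unet_double_blocks_0_img_attn_proj.alpha": torch.tensor(alpha),
}

ait_sd = _convert_kohya_flux_lora_to_diffusers(dict(sds_sd))
# Expected diffusers-style keys:
#   transformer.transformer_blocks.0.attn.to_out.0.lora_down.weight
#   transformer.transformer_blocks.0.attn.to_out.0.lora_up.weight
# Any unrecognized leftover key would make the converter return None (the early exit).
print(sorted(ait_sd.keys()))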
10 changes: 10 additions & 0 deletions extensions-builtin/Lora/lora_patches.py
@@ -20,6 +20,11 @@ def __init__(self):
    def apply(self):
        if self.active or shared.opts.lora_force_diffusers:
            return
        try:
            import bitsandbytes
            self.Linear4bit_forward = patches.patch(__name__, bitsandbytes.nn.Linear4bit, 'forward', networks.network_Linear4bit_forward)
        except:
            pass
        if "Model" in shared.opts.optimum_quanto_weights or "Text Encoder" in shared.opts.optimum_quanto_weights:
            from optimum import quanto # pylint: disable=no-name-in-module
            self.QLinear_forward = patches.patch(__name__, quanto.nn.QLinear, 'forward', networks.network_QLinear_forward) # pylint: disable=attribute-defined-outside-init
@@ -42,6 +47,11 @@ def apply(self):
    def undo(self):
        if not self.active or shared.opts.lora_force_diffusers:
            return
        try:
            import bitsandbytes
            self.Linear4bit_forward = patches.undo(__name__, bitsandbytes.nn.Linear4bit, 'forward')
        except:
            pass
        if "Model" in shared.opts.optimum_quanto_weights or "Text Encoder" in shared.opts.optimum_quanto_weights:
            from optimum import quanto # pylint: disable=no-name-in-module
            self.QLinear_forward = patches.undo(__name__, quanto.nn.QLinear, 'forward') # pylint: disable=E1128, attribute-defined-outside-init
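The apply/undo pair above goes through the repo's patches helper and networks.network_Linear4bit_forward; as a rough standalone sketch of the same guard-and-patch pattern (plain attribute swapping here, not the actual modules.patches mechanism, and lora_forward is a placeholder):

_original_forward = None

def lora_forward(self, x):
    # placeholder body standing in for networks.network_Linear4bit_forward
    return _original_forward(self, x)

def apply_patch():
    global _original_forward
    try:
        import bitsandbytes
    except ImportError:
        return  # bitsandbytes not installed: nothing to patch
    if _original_forward is None:
        _original_forward = bitsandbytes.nn.Linear4bit.forward
        bitsandbytes.nn.Linear4bit.forward = lora_forward

def undo_patch():
    global _original_forward
    try:
        import bitsandbytes
    except ImportError:
        return
    if _original_forward is not None:
        bitsandbytes.nn.Linear4bit.forward = _original_forward
        _original_forward = None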
9 changes: 5 additions & 4 deletions extensions-builtin/Lora/network_lora.py
@@ -27,7 +27,7 @@ def create_module(self, weights, key, none_ok=False):
        if weight is None and none_ok:
            return None
        linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear"}
        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
        is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
@@ -55,12 +55,13 @@ def create_module(self, weights, key, none_ok=False):
        return module

    def calc_updown(self, target): # pylint: disable=W0237
        up = self.up_model.weight.to(target.device, dtype=target.dtype)
        down = self.down_model.weight.to(target.device, dtype=target.dtype)
        target_dtype = target.dtype if target.dtype != torch.uint8 else self.up_model.weight.dtype
        up = self.up_model.weight.to(target.device, dtype=target_dtype)
        down = self.down_model.weight.to(target.device, dtype=target_dtype)
        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
            mid = self.mid_model.weight.to(target.device, dtype=target.dtype)
            mid = self.mid_model.weight.to(target.device, dtype=target_dtype)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
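The new target_dtype guard is what makes LoRA apply cleanly on top of nf4-quantized layers: bitsandbytes stores the packed 4-bit weights as torch.uint8, and casting the LoRA factors to that dtype would destroy them, so the code falls back to the LoRA weight's own dtype. A quick illustration (tensor values are arbitrary):

import torch

lora_up = torch.randn(4, 4, dtype=torch.float16)     # a typical LoRA factor
quant_weight = torch.zeros(4, 4, dtype=torch.uint8)  # stand-in for a packed nf4 weight

# Casting to the quantized layer's reported dtype truncates the values to integers:
print(lora_up.to(dtype=quant_weight.dtype))

# The guard above keeps the LoRA factor's own dtype instead:
target_dtype = quant_weight.dtype if quant_weight.dtype != torch.uint8 else lora_up.dtype
print(lora_up.to(dtype=target_dtype))  # unchanged float16 values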
1 change: 0 additions & 1 deletion extensions-builtin/Lora/network_overrides.py
@@ -30,7 +30,6 @@
    'kandinsky',
    'hunyuandit',
    'auraflow',
    'f1',
]

force_classes = [ # forced always
(diffs for the remaining 2 changed files not shown)
