Native Flux nf4 LoRA #3461

Merged · 7 commits · Sep 30, 2024
200 changes: 200 additions & 0 deletions extensions-builtin/Lora/lora_convert.py
@@ -1,6 +1,7 @@
import os
import re
import bisect
import torch
from typing import Dict
from modules import shared

@@ -174,6 +175,8 @@ def diffusers(self, key):
if search_key.startswith(map_key):
key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
if sd_module is None:
sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None) # FLUX1 fix
# SegMoE begin
expert_key = key + "_experts_0"
expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
@@ -253,3 +256,200 @@ def match(match_list, regex_text):
else:
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
return key


# Taken from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/lora_conversion_utils.py
# Modified from 'lora_A' and 'lora_B' to 'lora_down' and 'lora_up'
# Added early exit
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

ait_sd[ait_key + ".lora_down.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_up.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]

# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / sd_lora_rank

# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

down_weight = down_weight * scale_down
up_weight = up_weight * scale_up

# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]

# check upweight is sparse or not
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
# if is_sparse:
# print(f"weight is sparse: {sds_key}")

# make ai-toolkit weight
ait_down_keys = [k + ".lora_down.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_up.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})

# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416

# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]

def _convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(19):
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mod_lin",
f"transformer.transformer_blocks.{i}.norm1.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mod_lin",
f"transformer.transformer_blocks.{i}.norm1_context.linear",
)

for i in range(38):
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.proj_out",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_modulation_lin",
f"transformer.single_transformer_blocks.{i}.norm.linear",
)

if len(sds_sd) > 0:
return None

return ait_sd

return _convert_sd_scripts_to_ai_toolkit(state_dict)
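For context, the converter above remaps kohya-ss lora_unet_* keys to diffusers-style transformer.* keys: fused qkv (and single-block linear1) weights are split per projection, and the alpha/rank scale is folded into the down and up matrices as scale_down and scale_up so their product preserves the original scaling. A minimal usage sketch follows; the import path and the file name are illustrative assumptions, not part of the diff.

from safetensors.torch import load_file

# Hypothetical import path; the function lives in extensions-builtin/Lora/lora_convert.py.
from lora_convert import _convert_kohya_flux_lora_to_diffusers

sds_sd = load_file("flux_lora_kohya.safetensors")              # illustrative file with kohya-ss 'lora_unet_*' keys
ait_sd = _convert_kohya_flux_lora_to_diffusers(dict(sds_sd))   # pass a copy, the converter pops keys as it goes

if ait_sd is None:
    # the added early exit: leftover keys mean this is not a kohya-ss Flux LoRA
    print("conversion skipped")
else:
    # keys are now diffusers-style, e.g. 'transformer.transformer_blocks.0.attn.to_q.lora_down.weight'
    print(f"{len(ait_sd)} tensors converted")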
10 changes: 10 additions & 0 deletions extensions-builtin/Lora/lora_patches.py
@@ -20,6 +20,11 @@ def __init__(self):
def apply(self):
if self.active or shared.opts.lora_force_diffusers:
return
try:
import bitsandbytes
self.Linear4bit_forward = patches.patch(__name__, bitsandbytes.nn.Linear4bit, 'forward', networks.network_Linear4bit_forward)
except:
pass
if "Model" in shared.opts.optimum_quanto_weights or "Text Encoder" in shared.opts.optimum_quanto_weights:
from optimum import quanto # pylint: disable=no-name-in-module
self.QLinear_forward = patches.patch(__name__, quanto.nn.QLinear, 'forward', networks.network_QLinear_forward) # pylint: disable=attribute-defined-outside-init
@@ -42,6 +47,11 @@ def apply(self):
def undo(self):
if not self.active or shared.opts.lora_force_diffusers:
return
try:
import bitsandbytes
self.Linear4bit_forward = patches.undo(__name__, bitsandbytes.nn.Linear4bit, 'forward')
except:
pass
if "Model" in shared.opts.optimum_quanto_weights or "Text Encoder" in shared.opts.optimum_quanto_weights:
from optimum import quanto # pylint: disable=no-name-in-module
self.QLinear_forward = patches.undo(__name__, quanto.nn.QLinear, 'forward') # pylint: disable=E1128, attribute-defined-outside-init
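This lora_patches.py change wraps bitsandbytes.nn.Linear4bit.forward with networks.network_Linear4bit_forward while networks are active and restores it in undo(), with the try/except keeping the extension importable when bitsandbytes is not installed. Below is a rough sketch of how such a patch/undo helper could work; it only illustrates the mechanism, since modules.patches and the replacement forward are not part of this diff and the names here are stand-ins.

_originals = {}

def patch(owner, obj, field, replacement):
    # save the current attribute under an (owner, obj, field) key, then swap it out
    _originals[(owner, obj, field)] = getattr(obj, field)
    setattr(obj, field, replacement)
    return _originals[(owner, obj, field)]

def undo(owner, obj, field):
    # restore whatever patch() saved for this key
    setattr(obj, field, _originals.pop((owner, obj, field)))

try:
    import bitsandbytes

    def lora_linear4bit_forward(self, x):
        # stand-in for networks.network_Linear4bit_forward: apply active LoRA
        # deltas around the original 4-bit forward pass
        return original_forward(self, x)

    original_forward = patch(__name__, bitsandbytes.nn.Linear4bit, 'forward', lora_linear4bit_forward)
    # ... and later, mirroring undo():
    undo(__name__, bitsandbytes.nn.Linear4bit, 'forward')
except Exception:
    pass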
9 changes: 5 additions & 4 deletions extensions-builtin/Lora/network_lora.py
@@ -27,7 +27,7 @@ def create_module(self, weights, key, none_ok=False):
if weight is None and none_ok:
return None
linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear"}
is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
if is_linear:
weight = weight.reshape(weight.shape[0], -1)
@@ -55,12 +55,13 @@ def create_module(self, weights, key, none_ok=False):
return module

def calc_updown(self, target): # pylint: disable=W0237
up = self.up_model.weight.to(target.device, dtype=target.dtype)
down = self.down_model.weight.to(target.device, dtype=target.dtype)
target_dtype = target.dtype if target.dtype != torch.uint8 else self.up_model.weight.dtype
up = self.up_model.weight.to(target.device, dtype=target_dtype)
down = self.down_model.weight.to(target.device, dtype=target_dtype)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
mid = self.mid_model.weight.to(target.device, dtype=target.dtype)
mid = self.mid_model.weight.to(target.device, dtype=target_dtype)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
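The calc_updown() change above is what makes nf4 targets work: a bitsandbytes Linear4bit stores its packed 4-bit weights as torch.uint8, and casting the LoRA up/down matrices to that dtype would truncate them, so the code now falls back to the LoRA weights' own dtype whenever the target weight is uint8. A small self-contained sketch of that dtype rule (illustrative only, not taken from the diff):

import torch

def pick_lora_dtype(target_weight: torch.Tensor, lora_weight: torch.Tensor) -> torch.dtype:
    # packed nf4 weights are uint8; keep the LoRA math in the LoRA weight's float dtype instead
    return target_weight.dtype if target_weight.dtype != torch.uint8 else lora_weight.dtype

quantized = torch.zeros(16, dtype=torch.uint8)       # stand-in for a bitsandbytes 4-bit packed weight
lora_up = torch.randn(8, 4, dtype=torch.float16)

print(pick_lora_dtype(quantized, lora_up))           # torch.float16 -> LoRA math stays in float
print(pick_lora_dtype(torch.randn(4, 4), lora_up))   # torch.float32 -> follow the target as before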
1 change: 0 additions & 1 deletion extensions-builtin/Lora/network_overrides.py
@@ -30,7 +30,6 @@
'kandinsky',
'hunyuandit',
'auraflow',
'f1',
]

force_classes = [ # forced always