Resize lora add new rank for conv #1102

Merged
merged 3 commits on Feb 24, 2024
Changes from all commits
155 changes: 66 additions & 89 deletions networks/resize_lora.py
@@ -2,11 +2,12 @@
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo

import os
import argparse
import torch
from safetensors.torch import load_file, save_file, safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import train_util, model_util
from library import train_util
import numpy as np
from library.utils import setup_logging

@@ -35,7 +36,6 @@ def load_state_dict(file_name, dtype):

return sd, metadata


def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
@@ -47,7 +47,6 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
else:
torch.save(model, file_name)


# Indexing functions


@@ -62,19 +61,19 @@ def index_sv_cumulative(S, target):

def index_sv_fro(S, target):
S_squared = S.pow(2)
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / s_fro_sq
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
index = max(1, min(index, len(S)-1))

return index


def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv / target
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
index = max(1, min(index, len(S)-1))

return index

@@ -170,10 +169,10 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):

if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale * new_rank)
new_alpha = float(scale*new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale * new_rank)
new_alpha = float(scale*new_rank)

# Calculate resize info
s_sum = torch.sum(torch.abs(S))
@@ -193,29 +192,27 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
return param_dict


def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []

# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and "alpha" in key:
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim

scale = network_alpha / network_dim
scale = network_alpha/network_dim

if dynamic_method:
logger.info(
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
)
logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

lora_down_weight = None
lora_up_weight = None
@@ -227,56 +224,54 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if "lora_down" in key:
block_down_name = key.rsplit(".lora_down", 1)[0]
if 'lora_down' in key:
block_down_name = key.rsplit('.lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue

# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)

weights_loaded = lora_down_weight is not None and lora_up_weight is not None
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)

if weights_loaded:

conv2d = len(lora_down_weight.size()) == 4
conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha / lora_down_weight.size()[0]
scale = lora_alpha/lora_down_weight.size()[0]

if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

if verbose:
max_ratio = param_dict["max_ratio"]
sum_retained = param_dict["sum_retained"]
fro_retained = param_dict["fro_retained"]
max_ratio = param_dict['max_ratio']
sum_retained = param_dict['sum_retained']
fro_retained = param_dict['fro_retained']
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))

verbose_str += f"{block_down_name:75} | "
verbose_str += (
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
)
verbose_str += f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"

if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str += f"\n"
verbose_str += "\n"

new_alpha = param_dict["new_alpha"]
new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)

block_down_name = None
block_up_name = None
Expand All @@ -294,27 +289,30 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn


def resize(args):
if args.save_to is None or not (
args.save_to.endswith(".ckpt")
or args.save_to.endswith(".pt")
or args.save_to.endswith(".pth")
or args.save_to.endswith(".safetensors")
):
if (
args.save_to is None or
not (args.save_to.endswith('.ckpt') or
args.save_to.endswith('.pt') or
args.save_to.endswith('.pth') or
args.save_to.endswith('.safetensors'))
):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")

args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank

def str_to_dtype(p):
if p == "float":
if p == 'float':
return torch.float
if p == "fp16":
if p == 'fp16':
return torch.float16
if p == "bf16":
if p == 'bf16':
return torch.bfloat16
return None

if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")

merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
@@ -323,9 +321,7 @@ def str_to_dtype(p):
lora_sd, metadata = load_state_dict(args.model, merge_dtype)

logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(
lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
)
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

# update metadata
if metadata is None:
@@ -338,11 +334,9 @@ def str_to_dtype(p):
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = (
f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
)
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = 'Dynamic'
metadata["ss_network_alpha"] = 'Dynamic'

model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -355,45 +349,28 @@ def str_to_dtype(p):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat",
)
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--model",
type=str,
default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する"
)
parser.add_argument(
"--dynamic_method",
type=str,
default=None,
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
)
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--new_conv_rank", type=int, default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction")

return parser


if __name__ == "__main__":
if __name__ == '__main__':
parser = setup_parser()

args = parser.parse_args()
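Usage-wise, this change lets the Conv2d 3x3 rank be set separately from the linear rank, falling back to --new_rank when --new_conv_rank is omitted (see the resize() hunk above). A hypothetical invocation, with placeholder file names, might look like:

    python networks/resize_lora.py --model in.safetensors --save_to out.safetensors --new_rank 16 --new_conv_rank 8 --save_precision fp16

For reference, the dynamic sv_fro method keeps the smallest rank whose cumulative squared singular values reach target**2 of the squared Frobenius norm (clamped to at most len(S) - 1), so it retains roughly the target fraction of the Frobenius norm. Below is a minimal standalone sketch of that selection step, mirroring index_sv_fro above; the helper name and toy spectrum are illustrative, and PyTorch is assumed:

    import torch

    def pick_rank_sv_fro(S: torch.Tensor, target: float) -> int:
        # Cumulative share of the squared Frobenius norm covered by the top-k singular values.
        S_squared = S.pow(2)
        cumulative = torch.cumsum(S_squared, dim=0) / float(torch.sum(S_squared))
        # Insertion point of target**2 in the cumulative curve, plus one, clamped to [1, len(S) - 1].
        index = int(torch.searchsorted(cumulative, target**2)) + 1
        return max(1, min(index, len(S) - 1))

    S = torch.tensor([5.0, 2.0, 0.5, 0.1])  # toy singular values, descending
    print(pick_rank_sv_fro(S, target=0.9))  # -> 1: the top value alone retains over 90% of the Frobenius norm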