Skip to content

Commit

Permalink
Merge pull request bmaltais#43 from kohya-ss/dev
Browse files Browse the repository at this point in the history
Approximate difference of two models with LoRA, support multiple modules in generating
  • Loading branch information
kohya-ss authored Jan 6, 2023
2 parents d62725b + 39a0293 commit 54928fa
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 70 deletions.
117 changes: 52 additions & 65 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,3 @@
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...

# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
# v5: fix clip_sample=True for scheduler, add VGG guidance
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size,
# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction)
# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes,
# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1
# v10: fix app crashes when different image size in prompts

# Copyright 2022 kohya_ss @kohya_ss
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# license of included scripts:

# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE

# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.):
# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE

"""
VGG(
(features): Sequential(
Expand Down Expand Up @@ -517,7 +482,7 @@ def __init__(
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)

# region xformersとか使う部分:独自に書き換えるので関係なし
# region xformersとか使う部分:独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
Expand Down Expand Up @@ -1982,26 +1947,42 @@ def __getattr__(self, item):
vgg16_model.to(dtype).to(device)

# networkを組み込む
if args.network_module is not None:
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"

print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)

network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
if network is None:
return

print("load network weights from:", args.network_weights)
network.load_weights(args.network_weights)

network.apply_to(text_encoder, unet)

if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if args.network_module:
networks = []
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)

network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]

net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value

network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
if network is None:
return

if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network.load_weights(network_weight)

network.apply_to(text_encoder, unet)

if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)

networks.append(network)
else:
network = None
networks = []

if args.opt_channels_last:
print(f"set optimizing: channels last")
Expand All @@ -2010,8 +1991,9 @@ def __getattr__(self, item):
unet.to(memory_format=torch.channels_last)
if clip_model is not None:
clip_model.to(memory_format=torch.channels_last)
if network is not None:
network.to(memory_format=torch.channels_last)
if networks:
for network in networks:
network.to(memory_format=torch.channels_last)
if vgg16_model is not None:
vgg16_model.to(memory_format=torch.channels_last)

Expand Down Expand Up @@ -2053,7 +2035,7 @@ def load_images(path):
print(f"convert image to RGB from {image.mode}: {p}")
image = image.convert("RGB")
images.append(image)

return images

def resize_images(imgs, size):
Expand Down Expand Up @@ -2481,19 +2463,24 @@ def process_batch(batch, highres_fix, highres_1st=False):
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
parser.add_argument("--seed", type=int, default=None,
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
parser.add_argument("--iter_same_seed", action='store_true',
help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None,
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
Expand Down
158 changes: 158 additions & 0 deletions networks/extract_lora_from_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# extract approximating LoRA by svd from two SD models
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!

import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora


CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6


def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)

if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)


def svd(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None

save_dtype = str_to_dtype(args.save_precision)

print(f"loading SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
print(f"loading SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)

# create LoRA network to extract weights
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "

# get diffs
diffs = {}
text_encoder_different = False
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight

# Text Encoder might be same
if torch.max(torch.abs(diff)) > MIN_DIFF:
text_encoder_different = True

diff = diff.float()
diffs[lora_name] = diff

if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}

for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()

if args.device:
diff = diff.to(args.device)

diffs[lora_name] = diff

# make LoRA with svd
print("calculating by svd")
rank = args.dim
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()

U, S, Vh = torch.linalg.svd(mat)

U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)

Vh = Vh[:rank, :]

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val

U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)

lora_weights[lora_name] = (U, Vh)

# make state dict for LoRA
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
lora_sd = lora_network_o.state_dict()
print(f"LoRA has {len(lora_sd)} weights.")

for key in list(lora_sd.keys()):
lora_name = key.split('.')[0]
i = 0 if "lora_up" in key else 1

weights = lora_weights[lora_name][i]
# print(key, i, weights.size(), lora_sd[key].size())
if len(lora_sd[key].size()) == 4:
weights = weights.unsqueeze(2).unsqueeze(3)

assert weights.size() == lora_sd[key].size()
lora_sd[key] = weights

# load state dict to LoRA and save it
info = lora_network_o.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")

dir_name = os.path.dirname(args.save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)

lora_network_o.save_weights(args.save_to, save_dtype)
print(f"LoRA weights are saved to: {args.save_to}")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
parser.add_argument("--model_org", type=str, default=None,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
parser.add_argument("--model_tuned", type=str, default=None,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)")
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")

args = parser.parse_args()
svd(args)
42 changes: 37 additions & 5 deletions train_network_README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。

WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルに、このリポジトリ内のスクリプトであらかじめマージしておく必要があります。マージ後のモデルファイルはLoRAの学習結果が反映されたものになります。

なお当リポジトリ内の画像生成スクリプトで生成する場合はマージ不要です。
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。

## 学習方法

Expand All @@ -24,7 +22,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正

### DreamBoothの手法を用いる場合

note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594) を参照してデータを用意してください。
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。

学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。

Expand Down Expand Up @@ -110,7 +108,7 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt

### 複数のLoRAのモデルをマージする

結局のところSDモデルにマージしないと推論できないのであまり使い道はないかもしれません。ただ、複数のLoRAモデルをひとつずつSDモデルにマージしていく場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。

たとえば以下のようなコマンドラインになります。

Expand Down Expand Up @@ -144,6 +142,40 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim

--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。

## 二つのモデルの差分からLoRAモデルを作成する

[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。

二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。

### スクリプトの実行方法

以下のように指定してください。
```
python networks\extract_lora_from_models.py --model_org base-model.ckpt
--model_tuned fine-tuned-model.ckpt
--save_to lora-weights.safetensors --dim 4
```

--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。

--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。

--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。

生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。

Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。

### その他のオプション

- --v2
- v2.xのStable Diffusionモデルを使う場合に指定してください。
- --device
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
- --save_precision
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。

## 追加情報

### cloneofsimo氏のリポジトリとの違い
Expand Down

0 comments on commit 54928fa

Please sign in to comment.