Skip to content

Commit

Permalink
Merge pull request #66 from bmaltais/dev
Browse files Browse the repository at this point in the history
v20.4.0
  • Loading branch information
bmaltais committed Jan 22, 2023
2 parents f7e8a80 + 2ca17f6 commit c0b4c9b
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 63 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,23 @@ Once you have created the LoRA network you can generate images via auto1111 by i

## Change history

* 2023/01/16 (v20.3.0)
* 2023/01/22 (v20.4.0):
- Add support for `network_alpha` under the Training tab and support for `--training_comment` under the Folders tab.
- Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe!
- Details of the issue are described in https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 .
- The default value is ``1``, which gives a scale of ``1 / rank (or dimension)``. Set it to the same value as ``network_dim`` to get the same behavior as the old version.
- LoRA with a large dimension (rank) seems to require a higher learning rate with ``alpha=1`` (e.g. 1e-3 for 128-dim, still investigating). 
- For generating images in Web UI, __the latest version of the extension ``sd-webui-additional-networks`` (v0.3.0 or later) is required for the models trained with this release or later.__
- Add logging for the learning rate for U-Net and Text Encoder independently, and for running average epoch loss. Thanks to mgz-dev!
- Add more metadata such as dataset/reg image dirs, session ID, output name etc... See https://github.com/kohya-ss/sd-scripts/pull/77 for details. Thanks to space-nuko!
- __Now the metadata includes the folder name (the basename of the folder that contains the image files, not the full path).__ If you do not want it, disable metadata storing with the ``--no_metadata`` option.
- Add ``--training_comment`` option. You can specify an arbitrary string and refer to it by the extension.

It seems that the Stable Diffusion web UI now supports image generation using the LoRA model learned in this repository.

Note: At this time, it appears that models learned with version 0.4.0 are not supported. If you want to use the generation function of the web UI, please continue to use version 0.3.2. Also, it seems that LoRA models for SD2.x are not supported.

* 2023/01/16 (v20.3.0):
- Fix an issue where some LoRA modules are not trained when ``gradient_checkpointing`` is enabled.
- Add ``--save_last_n_epochs_state`` option. You can specify how many state folders to keep, apart from how many models to keep. Thanks to shirayu!
- Fix Text Encoder training stopping at ``max_train_steps`` even if ``max_train_epochs`` is set in ``train_db.py``.
Expand Down
19 changes: 8 additions & 11 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,7 +1981,6 @@ def __getattr__(self, item):
imported_module = importlib.import_module(network_module)

network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]

net_kwargs = {}
if args.network_args and i < len(args.network_args):
Expand All @@ -1992,22 +1991,22 @@ def __getattr__(self, item):
key, value = net_arg.split("=")
net_kwargs[key] = value

network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
if network is None:
return

if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)

if os.path.splitext(network_weight)[1] == '.safetensors':
if model_util.is_safetensors(network_weight):
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")

network.load_weights(network_weight)
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return

network.apply_to(text_encoder, unet)

Expand Down Expand Up @@ -2518,16 +2517,14 @@ def process_batch(batch, highres_fix, highres_1st=False):
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
Expand Down
40 changes: 36 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import math
import os
import random
import hashlib

from tqdm import tqdm
import torch
Expand Down Expand Up @@ -79,6 +80,11 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to
self.debug_dataset = debug_dataset
self.random_crop = random_crop
self.token_padding_disabled = False
self.dataset_dirs_info = {}
self.reg_dataset_dirs_info = {}
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None

self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

Expand Down Expand Up @@ -463,6 +469,8 @@ def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_toke
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
else:
self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height]
Expand Down Expand Up @@ -523,6 +531,7 @@ def load_dreambooth_dir(dir):
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images

Expand All @@ -539,6 +548,7 @@ def load_dreambooth_dir(dir):
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}

print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
Expand Down Expand Up @@ -611,6 +621,8 @@ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_to
self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0

self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}

# check existence of all npz files
if not self.color_aug:
npz_any = False
Expand Down Expand Up @@ -653,6 +665,8 @@ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_to
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
else:
self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height]
Expand All @@ -665,6 +679,9 @@ def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_to
self.bucket_resos.sort()
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]

self.min_bucket_reso = min([min(reso) for reso in resos])
self.max_bucket_reso = max([max(reso) for reso in resos])

def image_key_to_npz_file(self, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + '.npz'
Expand Down Expand Up @@ -749,9 +766,9 @@ def default(val, d):


def model_hash(filename):
"""Old model hash used by stable-diffusion-webui"""
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()

file.seek(0x100000)
Expand All @@ -761,6 +778,18 @@ def model_hash(filename):
return 'NOFILE'


def calculate_sha256(filename):
    """Return the hexadecimal SHA-256 digest of the whole file at *filename*.

    This is the "new model hash used by stable-diffusion-webui". The file is
    opened in binary mode and consumed in 1 MiB pieces, so the full contents
    are never held in memory at once.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()

    with open(filename, "rb") as stream:
        while True:
            block = stream.read(chunk_size)
            # read() returns b"" at end of file; nothing left to hash.
            if block == b"":
                break
            digest.update(block)

    return digest.hexdigest()


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135
Expand Down Expand Up @@ -1029,7 +1058,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
Expand All @@ -1048,8 +1078,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:

parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
parser.add_argument("--max_train_epochs", type=int, default=None,
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする")
Expand Down
28 changes: 26 additions & 2 deletions lora_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def save_configuration(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
):
# Get list of function parameters and values
parameters = list(locals().items())
Expand Down Expand Up @@ -175,6 +177,8 @@ def open_configuration(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
):
# Get list of function parameters and values
parameters = list(locals().items())
Expand Down Expand Up @@ -246,6 +250,8 @@ def train_model(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
Expand Down Expand Up @@ -358,6 +364,9 @@ def train_model(
run_cmd += f' --resolution={max_resolution}'
run_cmd += f' --output_dir="{output_dir}"'
run_cmd += f' --logging_dir="{logging_dir}"'
run_cmd += f' --network_alpha="{network_alpha}"'
if not training_comment == '':
run_cmd += f' --training_comment="{training_comment}"'
if not stop_text_encoder_training == 0:
run_cmd += (
f' --stop_text_encoder_training={stop_text_encoder_training}'
Expand Down Expand Up @@ -518,10 +527,15 @@ def lora_tab(
with gr.Row():
output_name = gr.Textbox(
label='Model output name',
placeholder='Name of the model to output',
placeholder='(Name of the model to output)',
value='last',
interactive=True,
)
training_comment = gr.Textbox(
label='Training comment',
placeholder='(Optional) Add training comment to be included in metadata',
interactive=True,
)
train_data_dir.change(
remove_doublequote,
inputs=[train_data_dir],
Expand Down Expand Up @@ -588,11 +602,19 @@ def lora_tab(
network_dim = gr.Slider(
minimum=1,
maximum=128,
label='Network Dimension',
label='Network Rank (Dimension)',
value=8,
step=1,
interactive=True,
)
network_alpha = gr.Slider(
minimum=1,
maximum=128,
label='Network Alpha',
value=1,
step=1,
interactive=True,
)
with gr.Row():
max_resolution = gr.Textbox(
label='Max resolution',
Expand Down Expand Up @@ -703,6 +725,8 @@ def lora_tab(
max_token_length,
max_train_epochs,
max_data_loader_n_workers,
network_alpha,
training_comment,
]

button_open_config.click(
Expand Down
7 changes: 4 additions & 3 deletions networks/check_lora_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ def main(file):

keys = list(sd.keys())
for key in keys:
if 'lora_up' in key:
if 'lora_up' in key or 'lora_down' in key:
values.append((key, sd[key]))
print(f"number of LoRA-up modules: {len(values)}")
print(f"number of LoRA modules: {len(values)}")

for key, value in values:
print(f"{key},{torch.mean(torch.abs(value))}")
value = value.to(torch.float32)
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit c0b4c9b

Please sign in to comment.