From 5dc2a0d3fd1a0cccf653aebf00ae17711f221008 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 30 Oct 2023 19:55:30 +0800 Subject: [PATCH 01/15] Add custom seperator --- library/config_util.py | 1 + library/train_util.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index e8e0fda7c..af4eedaa9 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -51,6 +51,7 @@ class BaseSubsetParams: image_dir: Optional[str] = None num_repeats: int = 1 shuffle_caption: bool = False + caption_seperator: str = ',', keep_tokens: int = 0 color_aug: bool = False flip_aug: bool = False diff --git a/library/train_util.py b/library/train_util.py index 51610e700..c04ad9a9b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -341,6 +341,7 @@ def __init__( image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, + caption_seperator: str, keep_tokens: int, color_aug: bool, flip_aug: bool, @@ -357,6 +358,7 @@ def __init__( self.image_dir = image_dir self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption + self.caption_seperator = caption_seperator self.keep_tokens = keep_tokens self.color_aug = color_aug self.flip_aug = flip_aug @@ -383,6 +385,7 @@ def __init__( caption_extension: str, num_repeats, shuffle_caption, + caption_seperator: str, keep_tokens, color_aug, flip_aug, @@ -402,6 +405,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, + caption_seperator, keep_tokens, color_aug, flip_aug, @@ -435,6 +439,7 @@ def __init__( metadata_file: str, num_repeats, shuffle_caption, + caption_seperator, keep_tokens, color_aug, flip_aug, @@ -454,6 +459,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, + caption_seperator, keep_tokens, color_aug, flip_aug, @@ -484,6 +490,7 @@ def __init__( caption_extension: str, num_repeats, shuffle_caption, + caption_seperator, keep_tokens, color_aug, flip_aug, @@ -503,6 +510,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, + caption_seperator, keep_tokens, color_aug, flip_aug, @@ -638,7 +646,7 @@ def process_caption(self, subset: BaseSubset, caption): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(",")] + tokens = [t.strip() for t in caption.strip().split(subset.caption_seperator)] if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: @@ -3091,7 +3099,10 @@ def add_dataset_arguments( # dataset common parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument( - "--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする" + "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" + ) + parser.add_argument( + "--caption_seperator", type=str, default=",", help="seperator for caption / captionの区切り文字" ) parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" From 583e2b2d0174e9a8bebcabfa178295e2980d334c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 30 Oct 2023 20:02:04 +0800 Subject: [PATCH 02/15] Fix typo --- library/train_util.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c04ad9a9b..35391e800 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -341,7 +341,7 @@ def __init__( image_dir: Optional[str], num_repeats: int, shuffle_caption: bool, - caption_seperator: str, + caption_separator: str, keep_tokens: int, color_aug: bool, flip_aug: bool, @@ -358,7 +358,7 @@ def __init__( self.image_dir = image_dir self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption - self.caption_seperator = caption_seperator + self.caption_separator = caption_separator self.keep_tokens = keep_tokens self.color_aug = color_aug self.flip_aug = flip_aug @@ -385,7 +385,7 @@ def __init__( caption_extension: str, num_repeats, shuffle_caption, - caption_seperator: str, + caption_separator: str, keep_tokens, color_aug, flip_aug, @@ -405,7 +405,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, - caption_seperator, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -439,7 +439,7 @@ def __init__( metadata_file: str, num_repeats, shuffle_caption, - caption_seperator, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -459,7 +459,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, - caption_seperator, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -490,7 +490,7 @@ def __init__( caption_extension: str, num_repeats, shuffle_caption, - caption_seperator, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -510,7 +510,7 @@ def __init__( image_dir, num_repeats, shuffle_caption, - caption_seperator, + caption_separator, keep_tokens, color_aug, flip_aug, @@ -646,7 +646,7 @@ def process_caption(self, subset: BaseSubset, caption): caption = "" else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: - tokens = [t.strip() for t in caption.strip().split(subset.caption_seperator)] + tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)] if subset.token_warmup_step < 1: # 初回に上書きする subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) if subset.token_warmup_step and self.current_step < subset.token_warmup_step: @@ -3102,7 +3102,7 @@ def add_dataset_arguments( "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" ) parser.add_argument( - "--caption_seperator", type=str, default=",", help="seperator for caption / captionの区切り文字" + "--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字" ) parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" From 489b728dbc7f85e22ca5e0fe4e7c91e2fb56c5f9 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 30 Oct 2023 20:19:51 +0800 Subject: [PATCH 03/15] Fix typo again --- library/config_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/config_util.py b/library/config_util.py index af4eedaa9..ab90fb63b 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -51,7 +51,7 @@ class BaseSubsetParams: image_dir: Optional[str] = None num_repeats: int = 1 shuffle_caption: bool = False - caption_seperator: str = ',', + caption_separator: str = ',', keep_tokens: int = 0 color_aug: bool = False flip_aug: bool = False From 6b3148fd3fb64e41aa29fc1759ebfab3a4504d45 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Tue, 7 Nov 2023 23:02:25 +0100 Subject: [PATCH 04/15] Fix min-snr-gamma for v-prediction and ZSNR. This fixes min-snr for vpred+zsnr by dividing directly by SNR+1. The old implementation did it in two steps: (min-snr/snr) * (snr/(snr+1)), which causes division by zero when combined with --zero_terminal_snr --- fine_tune.py | 2 +- library/custom_train_functions.py | 9 ++++++--- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 9 files changed, 14 insertions(+), 11 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 52e84c43f..b07876776 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -355,7 +355,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 28b625d30..e0a026dae 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -57,10 +57,13 @@ def enforce_zero_terminal_snr(betas): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) - snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) + if v_prediction: + snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device) + else: + snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device) loss = loss * snr_weight return loss diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 54abf697c..44447d1f0 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -460,7 +460,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index f00f10eaa..91cbacc6a 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -430,7 +430,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_controlnet.py b/train_controlnet.py index bbd915cb3..e0118d1c5 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -449,7 +449,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index 7fbbc18ac..966999dfb 100644 --- a/train_db.py +++ b/train_db.py @@ -342,7 +342,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: diff --git a/train_network.py b/train_network.py index d50916b74..1cbed2e7b 100644 --- a/train_network.py +++ b/train_network.py @@ -812,7 +812,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 6b6e7f5a0..45a437b91 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -578,7 +578,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8dd5c672f..f77ad2eb2 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -469,7 +469,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: From e20e9f61ac0f4987667fb692e843d569b95a3f12 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 8 Nov 2023 19:09:23 +0900 Subject: [PATCH 05/15] use **kwargs and change svd() calling convention to make svd() reusable * add required attributes to model_org, model_tuned, save_to * set "*_alpha" using str(float(foo)) --- networks/extract_lora_from_models.py | 80 ++++++++++++++-------------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index dba7cd4e2..e500185db 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def svd(args): +def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False): def str_to_dtype(p): if p == "float": return torch.float @@ -39,44 +39,44 @@ def str_to_dtype(p): return torch.bfloat16 return None - assert args.v2 != args.sdxl or ( - not args.v2 and not args.sdxl + assert v2 != sdxl or ( + not v2 and not sdxl ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" - if args.v_parameterization is None: - args.v_parameterization = args.v2 + if v_parameterization is None: + v_parameterization = v2 - save_dtype = str_to_dtype(args.save_precision) + save_dtype = str_to_dtype(save_precision) # load models - if not args.sdxl: - print(f"loading original SD model : {args.model_org}") - text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + if not sdxl: + print(f"loading original SD model : {model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org) text_encoders_o = [text_encoder_o] - print(f"loading tuned SD model : {args.model_tuned}") - text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + print(f"loading tuned SD model : {model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned) text_encoders_t = [text_encoder_t] - model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization) + model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization) else: - print(f"loading original SDXL model : {args.model_org}") + print(f"loading original SDXL model : {model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu" ) text_encoders_o = [text_encoder_o1, text_encoder_o2] - print(f"loading original SDXL model : {args.model_tuned}") + print(f"loading original SDXL model : {model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu" ) text_encoders_t = [text_encoder_t1, text_encoder_t2] model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha - if args.conv_dim is None: + if conv_dim is None: kwargs = {} else: - kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim} + kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim} - lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs) - lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs) + lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs) + lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs) assert len(lora_network_o.text_encoder_loras) == len( lora_network_t.text_encoder_loras ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " @@ -120,16 +120,16 @@ def str_to_dtype(p): lora_weights = {} with torch.no_grad(): for lora_name, mat in tqdm(list(diffs.items())): - # if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3 + # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3 conv2d = len(mat.size()) == 4 kernel_size = None if not conv2d else mat.size()[2:4] conv2d_3x3 = conv2d and kernel_size != (1, 1) - rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim + rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim out_dim, in_dim = mat.size()[0:2] - if args.device: - mat = mat.to(args.device) + if device: + mat = mat.to(device) # print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim) rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim @@ -178,34 +178,34 @@ def str_to_dtype(p): info = lora_network_save.load_state_dict(lora_sd) print(f"Loading extracted LoRA weights: {info}") - dir_name = os.path.dirname(args.save_to) + dir_name = os.path.dirname(save_to) if dir_name and not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) # minimum metadata net_kwargs = {} - if args.conv_dim is not None: - net_kwargs["conv_dim"] = args.conv_dim - net_kwargs["conv_alpha"] = args.conv_dim + if conv_dim is not None: + net_kwargs["conv_dim"] = str(conv_dim) + net_kwargs["conv_alpha"] = str(float(conv_dim)) metadata = { - "ss_v2": str(args.v2), + "ss_v2": str(v2), "ss_base_model_version": model_version, "ss_network_module": "networks.lora", - "ss_network_dim": str(args.dim), - "ss_network_alpha": str(args.dim), + "ss_network_dim": str(dim), + "ss_network_alpha": str(float(dim)), "ss_network_args": json.dumps(net_kwargs), } - if not args.no_metadata: - title = os.path.splitext(os.path.basename(args.save_to))[0] + if not no_metadata: + title = os.path.splitext(os.path.basename(save_to))[0] sai_metadata = sai_model_spec.build_metadata( - None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title + None, v2, v_parameterization, sdxl, True, False, time.time(), title=title ) metadata.update(sai_metadata) - lora_network_save.save_weights(args.save_to, save_dtype, metadata) - print(f"LoRA weights are saved to: {args.save_to}") + lora_network_save.save_weights(save_to, save_dtype, metadata) + print(f"LoRA weights are saved to: {save_to}") def setup_parser() -> argparse.ArgumentParser: @@ -213,7 +213,7 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む") parser.add_argument( "--v_parameterization", - type=bool, + action="store_true", default=None, help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)", ) @@ -231,16 +231,18 @@ def setup_parser() -> argparse.ArgumentParser: "--model_org", type=str, default=None, + required=True, help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors", ) parser.add_argument( "--model_tuned", type=str, default=None, + required=True, help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", ) parser.add_argument( - "--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" ) parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") parser.add_argument( @@ -264,4 +266,4 @@ def setup_parser() -> argparse.ArgumentParser: parser = setup_parser() args = parser.parse_args() - svd(args) + svd(**vars(args)) From 2c1e669bd868fa103f17823456a35304f534f2bb Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Wed, 8 Nov 2023 19:35:10 +0900 Subject: [PATCH 06/15] add min_diff, clamp_quantile args based on https://github.com/bmaltais/kohya_ss/pull/1332 https://github.com/bmaltais/kohya_ss/commit/a9ec90c40a1b390586edfba4f21a4acf6b6c09ad --- networks/extract_lora_from_models.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index e500185db..6c45d9770 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -13,8 +13,8 @@ import lora -CLAMP_QUANTILE = 0.99 -MIN_DIFF = 1e-1 +#CLAMP_QUANTILE = 0.99 +#MIN_DIFF = 1e-1 def save_to_file(file_name, model, state_dict, dtype): @@ -29,7 +29,7 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, no_metadata=False): +def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False): def str_to_dtype(p): if p == "float": return torch.float @@ -91,9 +91,9 @@ def str_to_dtype(p): diff = module_t.weight - module_o.weight # Text Encoder might be same - if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF: + if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff: text_encoder_different = True - print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}") + print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}") diff = diff.float() diffs[lora_name] = diff @@ -149,7 +149,7 @@ def str_to_dtype(p): Vh = Vh[:rank, :] dist = torch.cat([U.flatten(), Vh.flatten()]) - hi_val = torch.quantile(dist, CLAMP_QUANTILE) + hi_val = torch.quantile(dist, clamp_quantile) low_val = -hi_val U = U.clamp(low_val, hi_val) @@ -252,6 +252,18 @@ def setup_parser() -> argparse.ArgumentParser: help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)", ) parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う") + parser.add_argument( + "--clamp_quantile", + type=float, + default=0.99, + help="Quantile clamping value, float, (0-1). Default = 0.99", + ) + parser.add_argument( + "--min_diff", + type=float, + default=0.01, + help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01", + ) parser.add_argument( "--no_metadata", action="store_true", From d0923d6710a479eff42cdf5a6e9a361ca01b2655 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 Nov 2023 21:44:52 +0900 Subject: [PATCH 07/15] add caption_separator option --- finetune/tag_images_by_wd14_tagger.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 965edd7e2..fbf328e83 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -160,7 +160,9 @@ def main(args): tag_freq = {} - undesired_tags = set(args.undesired_tags.split(",")) + caption_separator = args.caption_separator + stripped_caption_separator = caption_separator.strip() + undesired_tags = set(args.undesired_tags.split(stripped_caption_separator)) def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) @@ -194,7 +196,7 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - general_tag_text += ", " + tag_name + general_tag_text += caption_separator + tag_name combined_tags.append(tag_name) elif i >= len(general_tags) and p >= args.character_threshold: tag_name = character_tags[i - len(general_tags)] @@ -203,18 +205,18 @@ def run_batch(path_imgs): if tag_name not in undesired_tags: tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1 - character_tag_text += ", " + tag_name + character_tag_text += caption_separator + tag_name combined_tags.append(tag_name) # 先頭のカンマを取る if len(general_tag_text) > 0: - general_tag_text = general_tag_text[2:] + general_tag_text = general_tag_text[len(caption_separator) :] if len(character_tag_text) > 0: - character_tag_text = character_tag_text[2:] + character_tag_text = character_tag_text[len(caption_separator) :] caption_file = os.path.splitext(image_path)[0] + args.caption_extension - tag_text = ", ".join(combined_tags) + tag_text = caption_separator.join(combined_tags) if args.append_tags: # Check if file exists @@ -224,13 +226,13 @@ def run_batch(path_imgs): existing_content = f.read().strip("\n") # Remove newlines # Split the content into tags and store them in a list - existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] + existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()] # Check and remove repeating tags in tag_text new_tags = [tag for tag in combined_tags if tag not in existing_tags] # Create new tag_text - tag_text = ", ".join(existing_tags + new_tags) + tag_text = caption_separator.join(existing_tags + new_tags) with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") @@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") + parser.add_argument( + "--caption_separator", + type=str, + default=", ", + help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください", + ) return parser From 6d6d86260b9b97fa79d84a6115d696312905d993 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 23 Nov 2023 19:40:48 +0900 Subject: [PATCH 08/15] add Deep Shrink --- gen_img_diffusers.py | 76 ++++++++++++++++++++++++++++++++++ library/original_unet.py | 68 ++++++++++++++++++++++++++++++- library/sdxl_original_unet.py | 70 ++++++++++++++++++++++++++++++- sdxl_gen_img.py | 77 +++++++++++++++++++++++++++++++++++ 4 files changed, 288 insertions(+), 3 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a596a0494..7661538c6 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2501,6 +2501,10 @@ def __getattr__(self, item): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Extended Textual Inversion および Textual Inversionを処理する if args.XTI_embeddings: diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI @@ -3085,6 +3089,13 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -3156,10 +3167,51 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): print(f"network mul: {network_muls}") continue + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -3509,6 +3561,30 @@ def setup_parser() -> argparse.ArgumentParser: # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + return parser diff --git a/library/original_unet.py b/library/original_unet.py index 240b85951..0454f13f1 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -361,6 +361,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class SampleOutput: def __init__(self, sample): self.sample = sample @@ -1130,6 +1147,11 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1221,6 +1243,11 @@ def forward( # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1417,6 +1444,31 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + # region diffusers compatibility def prepare_config(self): self.config = SimpleNamespace() @@ -1519,9 +1571,21 @@ def forward( # 2. pre-process sample = self.conv_in(sample) - # 3. down down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: + for depth, downsample_block in enumerate(self.down_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + org_dtype = sample.dtype + if org_dtype == torch.bfloat16: + sample = sample.to(torch.float32) + sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 # まあこちらのほうがわかりやすいかもしれない if downsample_block.has_cross_attention: diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 26a0af319..d51dfdbcc 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -266,6 +266,23 @@ def get_timestep_embedding( return emb +# Deep Shrink: We do not common this function, because minimize dependencies. +def resize_like(x, target, mode="bicubic", align_corners=False): + org_dtype = x.dtype + if org_dtype == torch.bfloat16: + x = x.to(torch.float32) + + if x.shape[-2:] != target.shape[-2:]: + if mode == "nearest": + x = F.interpolate(x, size=target.shape[-2:], mode=mode) + else: + x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners) + + if org_dtype == torch.bfloat16: + x = x.to(org_dtype) + return x + + class GroupNorm32(nn.GroupNorm): def forward(self, x): if self.weight.dtype != torch.float32: @@ -996,6 +1013,31 @@ def __init__( [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] ) + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + # region diffusers compatibility def prepare_config(self): self.config = SimpleNamespace() @@ -1077,16 +1119,42 @@ def call_module(module, h, emb, context): # h = x.type(self.dtype) h = x - for module in self.input_blocks: + + for depth, module in enumerate(self.input_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + # print("downsample", h.shape, self.ds_ratio) + org_dtype = h.dtype + if org_dtype == torch.bfloat16: + h = h.to(torch.float32) + h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + h = call_module(module, h, emb, context) hs.append(h) h = call_module(self.middle_block, h, emb, context) for module in self.output_blocks: + # Deep Shrink + if self.ds_depth_1 is not None: + if hs[-1].shape[-2:] != h.shape[-2:]: + # print("upsample", h.shape, hs[-1].shape) + h = resize_like(h, hs[-1]) + h = torch.cat([h, hs.pop()], dim=1) h = call_module(module, h, emb, context) + # Deep Shrink: in case of depth 0 + if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]: + # print("upsample", h.shape, x.shape) + h = resize_like(h, x) + h = h.type(x.dtype) h = call_module(self.out, h, emb, context) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c31ae0072..a61fb7a89 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1696,6 +1696,10 @@ def __getattr__(self, item): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() + # Deep Shrink + if args.ds_depth_1 is not None: + unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio) + # Textual Inversionを処理する if args.textual_inversion_embeddings: token_ids_embeds1 = [] @@ -2286,6 +2290,13 @@ def scale_and_round(x): clip_prompt = None network_muls = None + # Deep Shrink + ds_depth_1 = None # means no override + ds_timesteps_1 = args.ds_timesteps_1 + ds_depth_2 = args.ds_depth_2 + ds_timesteps_2 = args.ds_timesteps_2 + ds_ratio = args.ds_ratio + prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") @@ -2393,10 +2404,51 @@ def scale_and_round(x): print(f"network mul: {network_muls}") continue + # Deep Shrink + m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 1 + ds_depth_1 = int(m.group(1)) + print(f"deep shrink depth 1: {ds_depth_1}") + continue + + m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 1 + ds_timesteps_1 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 1: {ds_timesteps_1}") + continue + + m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink depth 2 + ds_depth_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink depth 2: {ds_depth_2}") + continue + + m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink timesteps 2 + ds_timesteps_2 = int(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink timesteps 2: {ds_timesteps_2}") + continue + + m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE) + if m: # deep shrink ratio + ds_ratio = float(m.group(1)) + ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override + print(f"deep shrink ratio: {ds_ratio}") + continue + except ValueError as ex: print(f"Exception in parsing / 解析エラー: {parg}") print(ex) + # override Deep Shrink + if ds_depth_1 is not None: + if ds_depth_1 < 0: + ds_depth_1 = args.ds_depth_1 or 3 + unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio) + # prepare seed if seeds is not None: # given in prompt # 数が足りないなら前のをそのまま使う @@ -2734,6 +2786,31 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", ) + + # Deep Shrink + parser.add_argument( + "--ds_depth_1", + type=int, + default=None, + help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする", + ) + parser.add_argument( + "--ds_timesteps_1", + type=int, + default=650, + help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps", + ) + parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2") + parser.add_argument( + "--ds_timesteps_2", + type=int, + default=650, + help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps", + ) + parser.add_argument( + "--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率" + ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) From 0fb9ecf1f39301d5362b7a5d143eabc949604128 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 25 Nov 2023 21:05:55 +0900 Subject: [PATCH 09/15] format by black, add ja comment --- networks/extract_lora_from_models.py | 39 +++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index 6c45d9770..6357df55d 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -13,8 +13,8 @@ import lora -#CLAMP_QUANTILE = 0.99 -#MIN_DIFF = 1e-1 +# CLAMP_QUANTILE = 0.99 +# MIN_DIFF = 1e-1 def save_to_file(file_name, model, state_dict, dtype): @@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype): torch.save(model, file_name) -def svd(model_org=None, model_tuned=None, save_to=None, dim=4, v2=None, sdxl=None, conv_dim=None, v_parameterization=None, device=None, save_precision=None, clamp_quantile=0.99, min_diff=0.01, no_metadata=False): +def svd( + model_org=None, + model_tuned=None, + save_to=None, + dim=4, + v2=None, + sdxl=None, + conv_dim=None, + v_parameterization=None, + device=None, + save_precision=None, + clamp_quantile=0.99, + min_diff=0.01, + no_metadata=False, +): def str_to_dtype(p): if p == "float": return torch.float @@ -39,9 +53,7 @@ def str_to_dtype(p): return torch.bfloat16 return None - assert v2 != sdxl or ( - not v2 and not sdxl - ), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" + assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません" if v_parameterization is None: v_parameterization = v2 @@ -199,9 +211,7 @@ def str_to_dtype(p): if not no_metadata: title = os.path.splitext(os.path.basename(save_to))[0] - sai_metadata = sai_model_spec.build_metadata( - None, v2, v_parameterization, sdxl, True, False, time.time(), title=title - ) + sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title) metadata.update(sai_metadata) lora_network_save.save_weights(save_to, save_dtype, metadata) @@ -242,7 +252,11 @@ def setup_parser() -> argparse.ArgumentParser: help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors", ) parser.add_argument( - "--save_to", type=str, default=None, required=True, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors" + "--save_to", + type=str, + default=None, + required=True, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors", ) parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)") parser.add_argument( @@ -256,13 +270,14 @@ def setup_parser() -> argparse.ArgumentParser: "--clamp_quantile", type=float, default=0.99, - help="Quantile clamping value, float, (0-1). Default = 0.99", + help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99", ) parser.add_argument( "--min_diff", type=float, default=0.01, - help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01", + help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /" + + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01", ) parser.add_argument( "--no_metadata", From c61e3bf4c9316326425d4400a7d28e8b4b810caa Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Nov 2023 18:11:30 +0900 Subject: [PATCH 10/15] make separate U-Net for inference --- gen_img_diffusers.py | 12 +- library/original_unet.py | 300 ++++++++++++++++++++++++++++------ library/sdxl_original_unet.py | 125 ++++++++++---- sdxl_gen_img.py | 11 +- 4 files changed, 366 insertions(+), 82 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 7661538c6..be43847a6 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -105,7 +105,7 @@ from networks.lora import LoRANetwork import tools.original_control_net as original_control_net from tools.original_control_net import ControlNetInfo -from library.original_unet import UNet2DConditionModel +from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel from library.original_unet import FlashAttentionFunction from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -378,7 +378,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: InferUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, clip_model: CLIPModel, @@ -2196,6 +2196,7 @@ def main(args): ) original_unet.load_state_dict(unet.state_dict()) unet = original_unet + unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet) # VAEを読み込む if args.vae is not None: @@ -2352,13 +2353,20 @@ def __getattr__(self, item): vae = sli_vae del sli_vae vae.to(dtype).to(device) + vae.eval() text_encoder.to(dtype).to(device) unet.to(dtype).to(device) + + text_encoder.eval() + unet.eval() + if clip_model is not None: clip_model.to(dtype).to(device) + clip_model.eval() if vgg16_model is not None: vgg16_model.to(dtype).to(device) + vgg16_model.eval() # networkを組み込む if args.network_module: diff --git a/library/original_unet.py b/library/original_unet.py index 0454f13f1..938b0b64c 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -1148,10 +1148,6 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - # Deep Shrink - if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: - hidden_states = resize_like(hidden_states, res_hidden_states) - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1244,10 +1240,6 @@ def forward( res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - # Deep Shrink - if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: - hidden_states = resize_like(hidden_states, res_hidden_states) - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: @@ -1444,31 +1436,6 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1) - # Deep Shrink - self.ds_depth_1 = None - self.ds_depth_2 = None - self.ds_timesteps_1 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - - def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): - if ds_depth_1 is None: - print("Deep Shrink is disabled.") - self.ds_depth_1 = None - self.ds_timesteps_1 = None - self.ds_depth_2 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - else: - print( - f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" - ) - self.ds_depth_1 = ds_depth_1 - self.ds_timesteps_1 = ds_timesteps_1 - self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 - self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 - self.ds_ratio = ds_ratio - # region diffusers compatibility def prepare_config(self): self.config = SimpleNamespace() @@ -1572,20 +1539,7 @@ def forward( sample = self.conv_in(sample) down_block_res_samples = (sample,) - for depth, downsample_block in enumerate(self.down_blocks): - # Deep Shrink - if self.ds_depth_1 is not None: - if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( - self.ds_depth_2 is not None - and depth == self.ds_depth_2 - and timesteps[0] < self.ds_timesteps_1 - and timesteps[0] >= self.ds_timesteps_2 - ): - org_dtype = sample.dtype - if org_dtype == torch.bfloat16: - sample = sample.to(torch.float32) - sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) - + for downsample_block in self.down_blocks: # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 # まあこちらのほうがわかりやすいかもしれない if downsample_block.has_cross_attention: @@ -1668,3 +1622,255 @@ def handle_unusual_timesteps(self, sample, timesteps): timesteps = timesteps.expand(sample.shape[0]) return timesteps + + +class InferUNet2DConditionModel: + def __init__(self, original_unet: UNet2DConditionModel): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # override original model's up blocks' forward method + for up_block in self.delegate.up_blocks: + if up_block.__class__.__name__ == "UpBlock2D": + + def resnet_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = resnet_wrapper(self.up_block_forward, up_block) + + elif up_block.__class__.__name__ == "CrossAttnUpBlock2D": + + def cross_attn_up_wrapper(func, block): + def forward(*args, **kwargs): + return func(block, *args, **kwargs) + + return forward + + up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block) + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in _self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def cross_attn_up_block_forward( + self, + _self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(_self.resnets, _self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # Deep Shrink + if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]: + hidden_states = resize_like(hidden_states, res_hidden_states) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if _self.upsamplers is not None: + for upsampler in _self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + ) -> Union[Dict, Tuple]: + r""" + current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink. + """ + + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dict instead of a plain tuple. + + Returns: + `SampleOutput` or `tuple`: + `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + _self = self.delegate + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある + # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する + # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い + default_overall_up_factor = 2**_self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + # 64で割り切れないときはupsamplerにサイズを伝える + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 1. time + timesteps = timestep + timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理 + + t_emb = _self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + # timestepsは重みを含まないので常にfloat32のテンソルを返す + # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある + # time_projでキャストしておけばいいんじゃね? + t_emb = t_emb.to(dtype=_self.dtype) + emb = _self.time_embedding(t_emb) + + # 2. pre-process + sample = _self.conv_in(sample) + + down_block_res_samples = (sample,) + for depth, downsample_block in enumerate(_self.down_blocks): + # Deep Shrink + if self.ds_depth_1 is not None: + if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( + self.ds_depth_2 is not None + and depth == self.ds_depth_2 + and timesteps[0] < self.ds_timesteps_1 + and timesteps[0] >= self.ds_timesteps_2 + ): + org_dtype = sample.dtype + if org_dtype == torch.bfloat16: + sample = sample.to(torch.float32) + sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype) + + # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、 + # まあこちらのほうがわかりやすいかもしれない + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # skip connectionにControlNetの出力を追加する + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + + # 4. mid + sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # ControlNetの出力を追加する + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(_self.up_blocks): + is_final_block = i == len(_self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection + + # if we have not reached the final block and need to forward the upsample size, we do it here + # 前述のように最後のブロック以外ではupsample_sizeを伝える + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + sample = _self.conv_norm_out(sample) + sample = _self.conv_act(sample) + sample = _self.conv_out(sample) + + if not return_dict: + return (sample,) + + return SampleOutput(sample=sample) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index d51dfdbcc..babda8ec5 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -24,7 +24,7 @@ import math from types import SimpleNamespace -from typing import Optional +from typing import Any, Optional import torch import torch.utils.checkpoint from torch import nn @@ -1013,31 +1013,6 @@ def __init__( [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] ) - # Deep Shrink - self.ds_depth_1 = None - self.ds_depth_2 = None - self.ds_timesteps_1 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - - def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): - if ds_depth_1 is None: - print("Deep Shrink is disabled.") - self.ds_depth_1 = None - self.ds_timesteps_1 = None - self.ds_depth_2 = None - self.ds_timesteps_2 = None - self.ds_ratio = None - else: - print( - f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" - ) - self.ds_depth_1 = ds_depth_1 - self.ds_timesteps_1 = ds_timesteps_1 - self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 - self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 - self.ds_ratio = ds_ratio - # region diffusers compatibility def prepare_config(self): self.config = SimpleNamespace() @@ -1120,7 +1095,97 @@ def call_module(module, h, emb, context): # h = x.type(self.dtype) h = x - for depth, module in enumerate(self.input_blocks): + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + +class InferSdxlUNet2DConditionModel: + def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs): + self.delegate = original_unet + + # override original model's forward method: because forward is not called by `__call__` + # overriding `__call__` is not enough, because nn.Module.forward has a special handling + self.delegate.forward = self.forward + + # Deep Shrink + self.ds_depth_1 = None + self.ds_depth_2 = None + self.ds_timesteps_1 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + + # call original model's methods + def __getattr__(self, name): + return getattr(self.delegate, name) + + def __call__(self, *args, **kwargs): + return self.delegate(*args, **kwargs) + + def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5): + if ds_depth_1 is None: + print("Deep Shrink is disabled.") + self.ds_depth_1 = None + self.ds_timesteps_1 = None + self.ds_depth_2 = None + self.ds_timesteps_2 = None + self.ds_ratio = None + else: + print( + f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]" + ) + self.ds_depth_1 = ds_depth_1 + self.ds_timesteps_1 = ds_timesteps_1 + self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1 + self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000 + self.ds_ratio = ds_ratio + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + r""" + current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink. + """ + _self = self.delegate + + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = _self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == _self.dtype + emb = emb + _self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + + for depth, module in enumerate(_self.input_blocks): # Deep Shrink if self.ds_depth_1 is not None: if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or ( @@ -1138,9 +1203,9 @@ def call_module(module, h, emb, context): h = call_module(module, h, emb, context) hs.append(h) - h = call_module(self.middle_block, h, emb, context) + h = call_module(_self.middle_block, h, emb, context) - for module in self.output_blocks: + for module in _self.output_blocks: # Deep Shrink if self.ds_depth_1 is not None: if hs[-1].shape[-2:] != h.shape[-2:]: @@ -1156,7 +1221,7 @@ def call_module(module, h, emb, context): h = resize_like(h, x) h = h.type(x.dtype) - h = call_module(self.out, h, emb, context) + h = call_module(_self.out, h, emb, context) return h diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index a61fb7a89..78b90f8c3 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -57,7 +57,7 @@ import library.sdxl_model_util as sdxl_model_util import library.sdxl_train_util as sdxl_train_util from networks.lora import LoRANetwork -from library.sdxl_original_unet import SdxlUNet2DConditionModel +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction from networks.control_net_lllite import ControlNetLLLite @@ -290,7 +290,7 @@ def __init__( vae: AutoencoderKL, text_encoders: List[CLIPTextModel], tokenizers: List[CLIPTokenizer], - unet: SdxlUNet2DConditionModel, + unet: InferSdxlUNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], clip_skip: int, ): @@ -328,7 +328,7 @@ def __init__( self.vae = vae self.text_encoders = text_encoders self.tokenizers = tokenizers - self.unet: SdxlUNet2DConditionModel = unet + self.unet: InferSdxlUNet2DConditionModel = unet self.scheduler = scheduler self.safety_checker = None @@ -1371,6 +1371,7 @@ def main(args): (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) + unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) # xformers、Hypernetwork対応 if not args.diffusers_xformers: @@ -1526,10 +1527,14 @@ def __getattr__(self, item): print("set vae_dtype to float32") vae_dtype = torch.float32 vae.to(vae_dtype).to(device) + vae.eval() text_encoder1.to(dtype).to(device) text_encoder2.to(dtype).to(device) unet.to(dtype).to(device) + text_encoder1.eval() + text_encoder2.eval() + unet.eval() # networkを組み込む if args.network_module: From 764e333fa2b560f2c5584f5ce9fb351467a7a218 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 Nov 2023 18:12:04 +0900 Subject: [PATCH 11/15] make slicing vae compatible with latest diffusers --- library/slicing_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/slicing_vae.py b/library/slicing_vae.py index 31b2bd0a4..5c4e056d3 100644 --- a/library/slicing_vae.py +++ b/library/slicing_vae.py @@ -62,7 +62,7 @@ def cat_h(sliced): return x -def resblock_forward(_self, num_slices, input_tensor, temb): +def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs): assert _self.upsample is None and _self.downsample is None assert _self.norm1.num_groups == _self.norm2.num_groups assert temb is None From 39bb319d4cac05d7da054ee726f86061e629574d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 29 Nov 2023 12:42:12 +0900 Subject: [PATCH 12/15] fix to work with cfg scale=1 --- sdxl_gen_img.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 78b90f8c3..ab5399842 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -504,7 +504,8 @@ def __call__( uncond_embeddings = tes_uncond_embs[0] for i in range(1, len(tes_text_embs)): text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + if do_classifier_free_guidance: + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 if do_classifier_free_guidance: if negative_scale is None: @@ -567,9 +568,11 @@ def __call__( text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) + if do_classifier_free_guidance: + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + vector_embeddings = torch.cat([uc_vector, c_vector]) + else: + vector_embeddings = c_vector # set timesteps self.scheduler.set_timesteps(num_inference_steps, self.device) From ee46134fa7f9b471b4aca90e4aba13102ed6cd02 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 3 Dec 2023 18:24:50 +0900 Subject: [PATCH 13/15] update readme --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index 0edaca25f..c4b91ea15 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,30 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Dec 3, 2023 / 2023/12/3 + +- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) +- Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934) + - See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details. +- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) + - The default values are same as the previous version. +- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`. + - `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages. + - `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages. + - `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink. + - `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available. + +- `finetune\tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) +- V-predicition (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934) + - 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。 +- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936) + - デフォルト値は前のバージョンと同じです。 +- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。 + - `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。 + - `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。 + - `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。 + - `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。 + ### Nov 5, 2023 / 2023/11/5 - `sdxl_train.py` now supports different learning rates for each Text Encoder. From 638ec29fe5c5b72f06816dbf7a6e943390a00abc Mon Sep 17 00:00:00 2001 From: bmaltais Date: Tue, 5 Dec 2023 07:29:23 -0500 Subject: [PATCH 14/15] Cleanup old torch1 code --- requirements_windows_torch1.txt | 5 - setup/setup_windows.py | 216 ++++++++++++++------------------ 2 files changed, 93 insertions(+), 128 deletions(-) delete mode 100644 requirements_windows_torch1.txt diff --git a/requirements_windows_torch1.txt b/requirements_windows_torch1.txt deleted file mode 100644 index 8be113361..000000000 --- a/requirements_windows_torch1.txt +++ /dev/null @@ -1,5 +0,0 @@ -torch==1.12.1+cu116 torchvision==0.13.1+cu116 --index-url https://download.pytorch.org/whl/cu116 # no_verify -https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl -U -I --no-deps # no_verify -bitsandbytes==0.35.0 # no_verify -tensorboard==2.10.1 tensorflow==2.10.1 --r requirements.txt diff --git a/setup/setup_windows.py b/setup/setup_windows.py index 5bb90a015..f11e3a9e4 100644 --- a/setup/setup_windows.py +++ b/setup/setup_windows.py @@ -7,20 +7,20 @@ import setup_common errors = 0 # Define the 'errors' variable before using it -log = logging.getLogger('sd') +log = logging.getLogger("sd") # ANSI escape code for yellow color -YELLOW = '\033[93m' -RESET_COLOR = '\033[0m' +YELLOW = "\033[93m" +RESET_COLOR = "\033[0m" def cudann_install(): cudnn_src = os.path.join( - os.path.dirname(os.path.realpath(__file__)), '..\cudnn_windows' + os.path.dirname(os.path.realpath(__file__)), "..\cudnn_windows" ) - cudnn_dest = os.path.join(sysconfig.get_paths()['purelib'], 'torch', 'lib') + cudnn_dest = os.path.join(sysconfig.get_paths()["purelib"], "torch", "lib") - log.info(f'Checking for CUDNN files in {cudnn_dest}...') + log.info(f"Checking for CUDNN files in {cudnn_dest}...") if os.path.exists(cudnn_src): if os.path.exists(cudnn_dest): # check for different files @@ -34,9 +34,9 @@ def cudann_install(): shutil.copy2(src_file, cudnn_dest) else: shutil.copy2(src_file, cudnn_dest) - log.info('Copied CUDNN 8.6 files to destination') + log.info("Copied CUDNN 8.6 files to destination") else: - log.warning(f'Destination directory {cudnn_dest} does not exist') + log.warning(f"Destination directory {cudnn_dest} does not exist") else: log.error(f'Installation Failed: "{cudnn_src}" could not be found.') @@ -50,18 +50,16 @@ def sync_bits_and_bytes_files(): """ # Only execute on Windows - if os.name != 'nt': - print('This function is only applicable to Windows OS.') + if os.name != "nt": + print("This function is only applicable to Windows OS.") return try: - log.info(f'Copying bitsandbytes files...') + log.info(f"Copying bitsandbytes files...") # Define source and destination directories - source_dir = os.path.join(os.getcwd(), 'bitsandbytes_windows') + source_dir = os.path.join(os.getcwd(), "bitsandbytes_windows") - dest_dir_base = os.path.join( - sysconfig.get_paths()['purelib'], 'bitsandbytes' - ) + dest_dir_base = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes") # Clear file comparison cache filecmp.clear_cache() @@ -71,8 +69,8 @@ def sync_bits_and_bytes_files(): source_file_path = os.path.join(source_dir, file) # Decide the destination directory based on file name - if file in ('main.py', 'paths.py'): - dest_dir = os.path.join(dest_dir_base, 'cuda_setup') + if file in ("main.py", "paths.py"): + dest_dir = os.path.join(dest_dir_base, "cuda_setup") else: dest_dir = dest_dir_base @@ -83,46 +81,19 @@ def sync_bits_and_bytes_files(): source_file_path, dest_file_path ): log.debug( - f'Skipping {source_file_path} as it already exists in {dest_dir}' + f"Skipping {source_file_path} as it already exists in {dest_dir}" ) else: # Copy file from source to destination, maintaining original file's metadata - log.debug(f'Copy {source_file_path} to {dest_dir}') + log.debug(f"Copy {source_file_path} to {dest_dir}") shutil.copy2(source_file_path, dest_dir) except FileNotFoundError as fnf_error: - log.error(f'File not found error: {fnf_error}') + log.error(f"File not found error: {fnf_error}") except PermissionError as perm_error: - log.error(f'Permission error: {perm_error}') + log.error(f"Permission error: {perm_error}") except Exception as e: - log.error(f'An unexpected error occurred: {e}') - - -# def install_kohya_ss_torch1(): -# setup_common.check_repo_version() -# setup_common.check_python() - -# # Upgrade pip if needed -# setup_common.install('--upgrade pip') - -# if setup_common.check_torch() == 2: -# input( -# f'{YELLOW}\nTorch 2 is already installed in the venv. To install Torch 1 delete the venv and re-run setup.bat\n\nHit enter to continue...{RESET_COLOR}' -# ) -# return - -# # setup_common.install( -# # 'torch==1.12.1+cu116 torchvision==0.13.1+cu116 --index-url https://download.pytorch.org/whl/cu116', -# # 'torch torchvision' -# # ) -# # setup_common.install( -# # 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl -U -I --no-deps', -# # 'xformers-0.0.14' -# # ) -# setup_common.install_requirements('requirements_windows_torch1.txt', check_no_verify_flag=False) -# sync_bits_and_bytes_files() -# setup_common.configure_accelerate(run_accelerate=True) -# # run_cmd(f'accelerate config') + log.error(f"An unexpected error occurred: {e}") def install_kohya_ss_torch2(): @@ -130,111 +101,110 @@ def install_kohya_ss_torch2(): setup_common.check_python() # Upgrade pip if needed - setup_common.install('--upgrade pip') - - # if setup_common.check_torch() == 1: - # input( - # f'{YELLOW}\nTorch 1 is already installed in the venv. To install Torch 2 delete the venv and re-run setup.bat\n\nHit any key to acknowledge.{RESET_COLOR}' - # ) - # return - - # setup_common.install( - # 'torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118', - # 'torch torchvision' - # ) - setup_common.install_requirements('requirements_windows_torch2.txt', check_no_verify_flag=False) - # install('https://huggingface.co/r4ziel/xformers_pre_built/resolve/main/triton-2.0.0-cp310-cp310-win_amd64.whl', 'triton', reinstall=reinstall) + setup_common.install("--upgrade pip") + + setup_common.install_requirements( + "requirements_windows_torch2.txt", check_no_verify_flag=False + ) + sync_bits_and_bytes_files() setup_common.configure_accelerate(run_accelerate=True) + # run_cmd(f'accelerate config') def install_bitsandbytes_0_35_0(): - log.info('Installing bitsandbytes 0.35.0...') - setup_common.install('--upgrade bitsandbytes==0.35.0', 'bitsandbytes 0.35.0', reinstall=True) + log.info("Installing bitsandbytes 0.35.0...") + setup_common.install( + "--upgrade bitsandbytes==0.35.0", "bitsandbytes 0.35.0", reinstall=True + ) sync_bits_and_bytes_files() + def install_bitsandbytes_0_40_1(): - log.info('Installing bitsandbytes 0.40.1...') - setup_common.install('--upgrade https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl', 'bitsandbytes 0.40.1', reinstall=True) + log.info("Installing bitsandbytes 0.40.1...") + setup_common.install( + "--upgrade https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.1.post1-py3-none-win_amd64.whl", + "bitsandbytes 0.40.1", + reinstall=True, + ) + def install_bitsandbytes_0_41_1(): - log.info('Installing bitsandbytes 0.41.1...') - setup_common.install('--upgrade https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl', 'bitsandbytes 0.41.1', reinstall=True) + log.info("Installing bitsandbytes 0.41.1...") + setup_common.install( + "--upgrade https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl", + "bitsandbytes 0.41.1", + reinstall=True, + ) + def main_menu(): setup_common.clear_screen() while True: - print('\nKohya_ss GUI setup menu:\n') - print('1. Install kohya_ss gui') - print('2. (Optional) Install cudann files (avoid unless you really need it)') - print('3. (Optional) Install specific bitsandbytes versions') - print('4. (Optional) Manually configure accelerate') - print('5. (Optional) Start Kohya_ss GUI in browser') - print('6. Quit') - - choice = input('\nEnter your choice: ') - print('') - - if choice == '1': + print("\nKohya_ss GUI setup menu:\n") + print("1. Install kohya_ss gui") + print("2. (Optional) Install cudann files (avoid unless you really need it)") + print("3. (Optional) Install specific bitsandbytes versions") + print("4. (Optional) Manually configure accelerate") + print("5. (Optional) Start Kohya_ss GUI in browser") + print("6. Quit") + + choice = input("\nEnter your choice: ") + print("") + + if choice == "1": install_kohya_ss_torch2() - # while True: - # print('1. Torch 1 (legacy, no longer supported. Will be removed in v21.9.x)') - # print('2. Torch 2 (recommended)') - # print('3. Cancel') - # choice_torch = input('\nEnter your choice: ') - # print('') - - # if choice_torch == '1': - # install_kohya_ss_torch1() - # break - # elif choice_torch == '2': - # install_kohya_ss_torch2() - # break - # elif choice_torch == '3': - # break - # else: - # print('Invalid choice. Please enter a number between 1-3.') - elif choice == '2': + elif choice == "2": cudann_install() - elif choice == '3': + elif choice == "3": while True: - print('1. (Optional) Force installation of bitsandbytes 0.35.0') - print('2. (Optional) Force installation of bitsandbytes 0.40.1 for new optimizer options support and pre-bugfix results') - print('3. (Optional) Force installation of bitsandbytes 0.41.1 for new optimizer options support') - print('4. (Danger) Install bitsandbytes-windows (this package has been reported to cause issues for most... avoid...)') - print('5. Cancel') - choice_torch = input('\nEnter your choice: ') - print('') - - if choice_torch == '1': + print("1. (Optional) Force installation of bitsandbytes 0.35.0") + print( + "2. (Optional) Force installation of bitsandbytes 0.40.1 for new optimizer options support and pre-bugfix results" + ) + print( + "3. (Optional) Force installation of bitsandbytes 0.41.1 for new optimizer options support" + ) + print( + "4. (Danger) Install bitsandbytes-windows (this package has been reported to cause issues for most... avoid...)" + ) + print("5. Cancel") + choice_torch = input("\nEnter your choice: ") + print("") + + if choice_torch == "1": install_bitsandbytes_0_35_0() break - elif choice_torch == '2': + elif choice_torch == "2": install_bitsandbytes_0_40_1() break - elif choice_torch == '3': + elif choice_torch == "3": install_bitsandbytes_0_41_1() break - elif choice_torch == '4': - setup_common.install('--upgrade bitsandbytes-windows', reinstall=True) + elif choice_torch == "4": + setup_common.install( + "--upgrade bitsandbytes-windows", reinstall=True + ) break - elif choice_torch == '5': + elif choice_torch == "5": break else: - print('Invalid choice. Please enter a number between 1-3.') - elif choice == '4': - setup_common.run_cmd('accelerate config') - elif choice == '5': - subprocess.Popen('start cmd /k .\gui.bat --inbrowser', shell=True) # /k keep the terminal open on quit. /c would close the terminal instead - elif choice == '6': - print('Quitting the program.') + print("Invalid choice. Please enter a number between 1-3.") + elif choice == "4": + setup_common.run_cmd("accelerate config") + elif choice == "5": + subprocess.Popen( + "start cmd /k .\gui.bat --inbrowser", shell=True + ) # /k keep the terminal open on quit. /c would close the terminal instead + elif choice == "6": + print("Quitting the program.") break else: - print('Invalid choice. Please enter a number between 1-5.') + print("Invalid choice. Please enter a number between 1-5.") -if __name__ == '__main__': +if __name__ == "__main__": setup_common.ensure_base_requirements() setup_common.setup_logging() main_menu() From 06eed6977b1f5f3cbaadc3115c571120e7293cc3 Mon Sep 17 00:00:00 2001 From: bmaltais Date: Wed, 6 Dec 2023 17:47:13 -0500 Subject: [PATCH 15/15] Add GLoRA support --- .release | 2 +- README.md | 3 ++- lora_gui.py | 43 +++++++++++++++++++++++++++++++++---------- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/.release b/.release index 6a112b855..d542ef359 100644 --- a/.release +++ b/.release @@ -1 +1 @@ -v22.2.2 +v22.3.0 diff --git a/README.md b/README.md index a8f880334..14244a46e 100644 --- a/README.md +++ b/README.md @@ -651,7 +651,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b ## Change History -* 2023/12/05 (v22.3.0) +* 2023/12/06 (v22.3.0) - Merge sd-scripts updates: - `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913) - Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934) @@ -663,6 +663,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b - `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages. - `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink. - `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available. + - Add GLoRA support * 2023/12/03 (v22.2.2) - Update Lycoris module to 2.0.0 (https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/README.md) diff --git a/lora_gui.py b/lora_gui.py index c7b05b602..1a8199d0e 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -732,6 +732,17 @@ def train_model( run_cmd += f" --network_module=lycoris.kohya" run_cmd += f' --network_args "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=locon"' + if LoRA_type == "LyCORIS/GLoRA": + try: + import lycoris + except ModuleNotFoundError: + log.info( + "\033[1;31mError:\033[0m The required module 'lycoris_lora' is not installed. Please install by running \033[33mupgrade.ps1\033[0m before running this program." + ) + return + run_cmd += f" --network_module=lycoris.kohya" + run_cmd += f' --network_args "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "algo=glora"' + if LoRA_type == "LyCORIS/LoHa": try: import lycoris @@ -741,7 +752,7 @@ def train_model( ) return run_cmd += f" --network_module=lycoris.kohya" - run_cmd += f' --network_args "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "use_cp={use_cp}" "algo=loha"' + run_cmd += f' --network_args "preset={LyCORIS_preset}" "conv_dim={conv_dim}" "conv_alpha={conv_alpha}" "use_tucker={use_cp}" "algo=loha"' # This is a hack to fix a train_network LoHA logic issue if not network_dropout > 0.0: run_cmd += f' --network_dropout="{network_dropout}"' @@ -787,7 +798,7 @@ def train_model( # This is a hack to fix a train_network LoHA logic issue if not network_dropout > 0.0: run_cmd += f' --network_dropout="{network_dropout}"' - + if LoRA_type == "LyCORIS/Native Fine-Tuning": try: import lycoris @@ -797,11 +808,13 @@ def train_model( ) return run_cmd += f" --network_module=lycoris.kohya" - run_cmd += f' --network_args "preset={LyCORIS_preset}" "algo=full" "train_norm=True"' + run_cmd += ( + f' --network_args "preset={LyCORIS_preset}" "algo=full" "train_norm=True"' + ) # This is a hack to fix a train_network LoHA logic issue if not network_dropout > 0.0: run_cmd += f' --network_dropout="{network_dropout}"' - + if LoRA_type == "LyCORIS/Diag-OFT": try: import lycoris @@ -815,7 +828,6 @@ def train_model( # This is a hack to fix a train_network LoHA logic issue if not network_dropout > 0.0: run_cmd += f' --network_dropout="{network_dropout}"' - if LoRA_type in ["Kohya LoCon", "Standard"]: kohya_lora_var_list = [ @@ -967,7 +979,7 @@ def train_model( if full_bf16: run_cmd += f" --full_bf16" - + if debiased_estimation_loss: run_cmd += " --debiased_estimation_loss" @@ -1139,6 +1151,7 @@ def list_presets(path): "LyCORIS/DyLoRA", "LyCORIS/iA3", "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", "LyCORIS/LoCon", "LyCORIS/LoHa", "LyCORIS/LoKr", @@ -1158,7 +1171,8 @@ def list_presets(path): "unet-convblock-only", ], value="full", - visible=False, interactive=True + visible=False, + interactive=True # info="https://github.com/KohakuBlueleaf/LyCORIS/blob/0006e2ffa05a48d8818112d9f70da74c0cd30b99/docs/Preset.md" ) with gr.Box(): @@ -1235,7 +1249,7 @@ def list_presets(path): with gr.Row() as LoRA_dim_alpha: network_dim = gr.Slider( minimum=1, - maximum=100000, # 512 if not LoRA_type == "LyCORIS/LoKr" else 100000, + maximum=100000, # 512 if not LoRA_type == "LyCORIS/LoKr" else 100000, label="Network Rank (Dimension)", value=8, step=1, @@ -1254,7 +1268,7 @@ def list_presets(path): # locon= gr.Checkbox(label='Train a LoCon instead of a general LoRA (does not support v2 base models) (may not be able to some utilities now)', value=False) conv_dim = gr.Slider( minimum=0, - maximum=100000, # 512 if not LoRA_type == "LyCORIS/LoKr" else 100000, + maximum=100000, # 512 if not LoRA_type == "LyCORIS/LoKr" else 100000, value=1, step=1, label="Convolution Rank (Dimension)", @@ -1322,6 +1336,7 @@ def update_LoRA_settings(LoRA_type): "LoRA-FA", "LyCORIS/Diag-OFT", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoCon", "LyCORIS/LoHa", "LyCORIS/LoKr", @@ -1340,6 +1355,7 @@ def update_LoRA_settings(LoRA_type): "LyCORIS/LoHa", "LyCORIS/LoKr", "LyCORIS/LoCon", + "LyCORIS/GLoRA", }, gr.Row, ), @@ -1365,6 +1381,7 @@ def update_LoRA_settings(LoRA_type): "LoRA-FA", "LyCORIS/Diag-OFT", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoHa", "LyCORIS/LoCon", "LyCORIS/LoKr", @@ -1380,6 +1397,7 @@ def update_LoRA_settings(LoRA_type): "LoRA-FA", "LyCORIS/Diag-OFT", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoHa", "LyCORIS/LoCon", "LyCORIS/LoKr", @@ -1395,6 +1413,7 @@ def update_LoRA_settings(LoRA_type): "LoRA-FA", "LyCORIS/Diag-OFT", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoHa", "LyCORIS/LoCon", "LyCORIS/LoKr", @@ -1406,6 +1425,7 @@ def update_LoRA_settings(LoRA_type): { "LyCORIS/DyLoRA", "LyCORIS/LoHa", + "LyCORIS/GLoRA", "LyCORIS/LoCon", "LyCORIS/LoKr", }, @@ -1420,6 +1440,7 @@ def update_LoRA_settings(LoRA_type): "Kohya LoCon", "LoRA-FA", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoHa", "LyCORIS/LoCon", "LyCORIS/LoKr", @@ -1435,6 +1456,7 @@ def update_LoRA_settings(LoRA_type): "LoRA-FA", "LyCORIS/Diag-OFT", "LyCORIS/DyLoRA", + "LyCORIS/GLoRA", "LyCORIS/LoHa", "LyCORIS/LoCon", "LyCORIS/LoKr", @@ -1467,13 +1489,14 @@ def update_LoRA_settings(LoRA_type): "LyCORIS/DyLoRA", "LyCORIS/iA3", "LyCORIS/Diag-OFT", + "LyCORIS/GLoRA", "LyCORIS/LoCon", "LyCORIS/LoHa", "LyCORIS/LoKr", "LyCORIS/Native Fine-Tuning", }, gr.Dropdown, - ) + ), } results = []