From a26522597256abc1e6ddc40936dd0e6c2179c266 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Mon, 20 Mar 2023 22:51:38 +0000 Subject: [PATCH 1/5] Min-SNR Weighting Strategy --- library/train_util.py | 2 +- train_network.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 7d311827d..e68444a0e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1927,7 +1927,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: diff --git a/train_network.py b/train_network.py index 7f910df40..5cb08f15b 100644 --- a/train_network.py +++ b/train_network.py @@ -548,6 +548,16 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + gamma = args.min_snr_gamma + if gamma: + sigma = torch.sub(noisy_latents, latents) #find noise as applied + zeros = torch.zeros_like(sigma) + alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square + sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square + snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From 64c923230e54a92b1a132e4918dad190935da461 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Wed, 22 Mar 2023 01:25:49 +0000 Subject: [PATCH 2/5] Min-SNR Weighting Strategy: Refactored and added to all trainers --- fine_tune.py | 8 +++++++- library/custom_train_functions.py | 17 +++++++++++++++++ library/train_util.py | 2 +- train_db.py | 7 ++++++- train_network.py | 17 ++++++----------- train_textual_inversion.py | 6 ++++++ 6 files changed, 43 insertions(+), 14 deletions(-) create mode 100644 library/custom_train_functions.py diff --git a/fine_tune.py b/fine_tune.py index 1acf478f4..ff33eb9c9 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -19,7 +19,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -304,6 +305,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] @@ -396,6 +400,8 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by 
diffusers / Diffusersでxformersを使用する") parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py new file mode 100644 index 000000000..f60ec7436 --- /dev/null +++ b/library/custom_train_functions.py @@ -0,0 +1,17 @@ +import torch +import argparse + +def apply_snr_weight(loss, latents, noisy_latents, gamma): + sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler + zeros = torch.zeros_like(sigma) + alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment + sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment + snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + loss = loss * snr_weight + print(snr_weight) + return loss + +def add_custom_train_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") \ No newline at end of file diff --git a/library/train_util.py b/library/train_util.py index ffe81d693..a0e98cb12 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1963,7 +1963,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" ) - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 
5 is recommended by paper.") + def verify_training_args(args: argparse.Namespace): if args.v_parameterization and not args.v2: diff --git a/train_db.py b/train_db.py index 527f8e9bc..ee9beda9d 100644 --- a/train_db.py +++ b/train_db.py @@ -21,7 +21,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -291,6 +292,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし accelerator.backward(loss) @@ -390,6 +394,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_sd_saving_arguments(parser) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--no_token_padding", diff --git a/train_network.py b/train_network.py index dce706186..715da8c11 100644 --- a/train_network.py +++ b/train_network.py @@ -23,7 +23,8 @@ ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight def collate_fn(examples): return examples[0] @@ -548,16 +549,9 @@ def train(args): loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - gamma = args.min_snr_gamma - if gamma: - sigma = torch.sub(noisy_latents, latents) #find noise as applied - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square - snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares - gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper - loss = loss * snr_weight + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -662,6 +656,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument( diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 85f0d57c3..5fe662f6a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -17,6 +17,8 @@ ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [ "a photo of a {}", @@ -377,6 +379,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights @@ -534,6 +539,7 @@ def 
setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--save_model_as", From a3c7d711e4f1160fb873755c77074a89f5bf700a Mon Sep 17 00:00:00 2001 From: AI-Casanova <54461896+AI-Casanova@users.noreply.github.com> Date: Tue, 21 Mar 2023 20:38:27 -0500 Subject: [PATCH 3/5] Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation --- library/custom_train_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index f60ec7436..5e880c9a7 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -10,8 +10,8 @@ def apply_snr_weight(loss, latents, noisy_latents, gamma): gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper loss = loss * snr_weight - print(snr_weight) + #print(snr_weight) return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") \ No newline at end of file + parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") From 518a18aeff9c2f9a830565e0623729fd938590d3 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Thu, 23 Mar 2023 12:34:49 +0000 Subject: [PATCH 4/5] (ACTUAL) Min-SNR Weighting Strategy: Fixed SNR calculation to authors implementation --- fine_tune.py | 2 +- library/custom_train_functions.py | 20 ++++++++++++-------- train_db.py | 3 ++- train_network.py | 4 +--- train_textual_inversion.py | 2 +- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index ff33eb9c9..45f4b9db2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -306,7 +306,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 5e880c9a7..b080b40c6 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,16 +1,20 @@ import torch import argparse +import numpy as np -def apply_snr_weight(loss, latents, noisy_latents, gamma): - sigma = torch.sub(noisy_latents, latents) #find noise as applied by scheduler - zeros = torch.zeros_like(sigma) - alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment - sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) #trick to get Mean Square/Second Moment - snr = torch.div(alpha_mean_sq,sigma_mean_sq) #Signal to Noise Ratio = ratio of Mean Squares + +def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): + alphas_cumprod = 
noise_scheduler.alphas_cumprod.cpu() + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + alpha = sqrt_alphas_cumprod + sigma = sqrt_one_minus_alphas_cumprod + all_snr = (alpha / sigma) ** 2 + all_snr.to(loss.device) + snr = torch.stack([all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper loss = loss * snr_weight - #print(snr_weight) return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): diff --git a/train_db.py b/train_db.py index ee9beda9d..52195b92a 100644 --- a/train_db.py +++ b/train_db.py @@ -293,7 +293,8 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 715da8c11..145dd600d 100644 --- a/train_network.py +++ b/train_network.py @@ -489,7 +489,6 @@ def train(args): noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False ) - if accelerator.is_main_process: accelerator.init_trackers("network_train") @@ -529,7 +528,6 @@ def train(args): # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -551,7 +549,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 5fe662f6a..0694dbb6a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -381,7 +381,7 @@ def train(args): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, latents, noisy_latents, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights From 4c06bfad60be71f6f13f2d14413694c2e0be7813 Mon Sep 17 00:00:00 2001 From: AI-Casanova Date: Sun, 26 Mar 2023 00:01:29 +0000 Subject: [PATCH 5/5] Fix for TypeError from bf16 precision: Thanks to mgz-dev --- library/custom_train_functions.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index b080b40c6..fd4f6156b 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,21 +1,18 @@ import torch import argparse -import numpy as np - def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): - alphas_cumprod = noise_scheduler.alphas_cumprod.cpu() - sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) - sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + alphas_cumprod = noise_scheduler.alphas_cumprod 
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) alpha = sqrt_alphas_cumprod sigma = sqrt_one_minus_alphas_cumprod all_snr = (alpha / sigma) ** 2 - all_snr.to(loss.device) snr = torch.stack([all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr) - snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper + snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper loss = loss * snr_weight return loss def add_custom_train_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--min_snr_gamma", type=float, default=0, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") + parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")
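
Two notes for readers following the series.

Why PATCH 3 and PATCH 4 replace the original SNR estimate: the forward process is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so the quantity PATCH 1 measures, noisy_latents - latents = (sqrt(alpha_bar_t) - 1) * x_0 + sqrt(1 - alpha_bar_t) * eps, is not the pure noise term, and a ratio of empirical mean squares also varies with the particular latents in the batch. The authors' implementation instead reads the signal-to-noise ratio directly off the schedule, SNR(t) = alpha_bar_t / (1 - alpha_bar_t). The min-SNR-gamma weight for epsilon-prediction is then min(SNR(t), gamma) / SNR(t) = min(gamma / SNR(t), 1), which is exactly the torch.minimum expression in the code. The recommended gamma = 5 comes from the paper these commits reference, presumably Hang et al., "Efficient Diffusion Training via Min-SNR Weighting Strategy" (arXiv:2303.09556).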
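
For reference, this is the final state of library/custom_train_functions.py as the five diffs leave it, reproduced with descriptive comments added (the comments are annotations for this write-up, not lines from the patches):

    import torch
    import argparse

    def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
        # alpha_bar_t for every training timestep; staying in torch instead of
        # numpy (PATCH 5) keeps the math dtype-agnostic, which is what fixed
        # the TypeError seen under bf16 with the PATCH 4 version.
        alphas_cumprod = noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)                  # signal scale sqrt(alpha_bar_t)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)  # noise scale sqrt(1 - alpha_bar_t)
        alpha = sqrt_alphas_cumprod
        sigma = sqrt_one_minus_alphas_cumprod
        all_snr = (alpha / sigma) ** 2  # SNR(t) = alpha_bar_t / (1 - alpha_bar_t), one entry per timestep
        snr = torch.stack([all_snr[t] for t in timesteps])  # SNR for the timesteps sampled in this batch
        gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
        # min(gamma / SNR(t), 1): down-weights only the low-noise, high-SNR timesteps.
        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()  # from paper
        loss = loss * snr_weight  # expects a per-sample loss of shape (batch_size,)
        return loss

    def add_custom_train_arguments(parser: argparse.ArgumentParser):
        parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.")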
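
And a minimal, self-contained sketch of how the trainers above call it. The random tensors stand in for a real noise_pred/target pair, the scheduler configuration copies the one in train_network.py, and the import assumes the sd-scripts repository layout. One shape detail worth noting: apply_snr_weight multiplies by a (batch_size,) weight, so the loss must still be per-sample, not yet reduced to a scalar, when it is called.

    import torch
    from diffusers import DDPMScheduler
    from library.custom_train_functions import apply_snr_weight

    # Stand-ins for one training step; shapes follow SD latents (B, 4, H/8, W/8).
    batch_size = 4
    noise_pred = torch.randn(batch_size, 4, 64, 64)  # placeholder for the U-Net output
    target = torch.randn(batch_size, 4, 64, 64)      # placeholder for the noise target
    min_snr_gamma = 5.0                              # the value the paper recommends

    # Same scheduler configuration as train_network.py above.
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        num_train_timesteps=1000, clip_sample=False,
    )
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,)).long()

    # Element-wise loss first, then one value per sample: the SNR weight is
    # per timestep, so it has to land before the final batch mean.
    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
    loss = loss.mean([1, 2, 3])  # shape: (batch_size,)

    if min_snr_gamma:  # the flag defaults to None after PATCH 5, so this stays off unless set
        loss = apply_snr_weight(loss, timesteps, noise_scheduler, min_snr_gamma)

    loss = loss.mean()  # scalar, ready for accelerator.backward(loss)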