
Efficient Diffusion Training via Min-SNR Weighting Strategy #308

Merged: 7 commits, Mar 26, 2023
8 changes: 7 additions & 1 deletion fine_tune.py
@@ -20,7 +20,8 @@
ConfigSanitizer,
BlueprintGenerator,
)

import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

def train(args):
train_util.verify_training_args(args)
@@ -309,6 +310,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])  # reduce to one loss value per sample so the SNR weight can be applied per timestep

if args.min_snr_gamma:
    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

loss = loss.mean()

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
    params_to_clip = []
@@ -401,6 +405,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
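With this wiring in place, Min-SNR weighting in fine_tune.py is opt-in via the new flag registered by add_custom_train_arguments. A hypothetical invocation might look like the following; only --min_snr_gamma comes from this PR, and the remaining arguments are placeholders for whatever you already pass to the script:

accelerate launch fine_tune.py \
  --pretrained_model_name_or_path=<base model> \
  --in_json=<metadata json> \
  --train_data_dir=<training images> \
  --output_dir=<output dir> \
  --min_snr_gamma=5

Leaving the flag unset (the default of None) skips the weighting entirely; the paper recommends gamma = 5.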
18 changes: 18 additions & 0 deletions library/custom_train_functions.py
@@ -0,0 +1,18 @@
import torch
import argparse


def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
    # Min-SNR weighting: rescale each sample's loss by min(gamma / SNR(t), 1) so that
    # low-noise (high-SNR) timesteps cannot dominate training.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)  # keep the lookup on the same device as timesteps
    alpha = torch.sqrt(alphas_cumprod)
    sigma = torch.sqrt(1.0 - alphas_cumprod)
    all_snr = (alpha / sigma) ** 2  # SNR(t) = alpha_t^2 / sigma_t^2
    snr = all_snr[timesteps]  # gather each sample's SNR in a single indexing op
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()  # clamp at 1, as in the paper
    loss = loss * snr_weight
    return loss


def add_custom_train_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high-loss timesteps. Lower values have a stronger effect; the paper recommends 5.",
    )
2 changes: 1 addition & 1 deletion library/train_util.py
@@ -2001,7 +2001,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
    "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)


def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2:
8 changes: 7 additions & 1 deletion train_db.py
@@ -22,7 +22,8 @@
ConfigSanitizer,
BlueprintGenerator,
)

import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

def train(args):
train_util.verify_training_args(args)
@@ -296,6 +297,10 @@ def train(args):
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)


loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

accelerator.backward(loss)
@@ -395,6 +400,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument(
    "--no_token_padding",
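Each touched training script follows the same pattern: the weight is multiplied into the loss while it is still one value per sample, before the final mean. For reference, a sketch of the objective these hunks implement, in the paper's notation for epsilon-prediction:

\mathcal{L} = \mathbb{E}_{x_0,\, \epsilon,\, t}\left[ \min\!\left(\frac{\gamma}{\mathrm{SNR}(t)},\, 1\right) \lVert \epsilon_\theta(x_t, t) - \epsilon \rVert_2^2 \right], \qquad \mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}

Samples at high-SNR timesteps thus contribute at most gamma / SNR(t) of their unweighted loss, while low-SNR timesteps are unaffected.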
9 changes: 7 additions & 2 deletions train_network.py
@@ -24,6 +24,9 @@
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight


# TODO: share this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
@@ -492,7 +495,6 @@ def train(args):
noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)

if accelerator.is_main_process:
    accelerator.init_trackers("network_train")

@@ -534,7 +536,6 @@ def train(args):
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -554,6 +555,9 @@

loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights

if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし

@@ -658,6 +662,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(
6 changes: 6 additions & 0 deletions train_textual_inversion.py
@@ -18,6 +18,8 @@
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight

imagenet_templates_small = [
"a photo of a {}",
@@ -383,6 +385,9 @@ def train(args):

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

loss_weights = batch["loss_weights"]  # per-sample weights
loss = loss * loss_weights
@@ -540,6 +545,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)

parser.add_argument(
    "--save_model_as",