-
Notifications
You must be signed in to change notification settings - Fork 877
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Efficient Diffusion Training via Min-SNR Weighting Strategy #308
Conversation
@bmaltais I thought you may be interested in trying this. |
Thank you, look interesting. |
what does the comparison show when the paper talks about converging faster? |
One possible improvement to the code could be to create a function in a seperate custom_train_functions.py file to be called by all trainer: import torch
def apply_snr_weight(loss, noisy_latents, latents, gamma):
gamma = gamma
if gamma:
sigma = torch.sub(noisy_latents, latents)
zeros = torch.zeros_like(sigma)
alpha_mean_sq = torch.nn.functional.mse_loss(latents.float(), zeros.float(), reduction="none").mean([1, 2, 3])
sigma_mean_sq = torch.nn.functional.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3])
snr = torch.div(alpha_mean_sq, sigma_mean_sq)
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()
loss = loss * snr_weight
return loss That way all you need is add this to each trainers: from library.custom_train_functions import apply_snr_weight
loss = apply_snr_weight(loss, noisy_latents, latents, args.min_snr_gamma) |
I gave it a try and it does appear to have a significantly positive effect on the training. This is a keeper! |
I do like the cleanliness of that. It could also have the associated argparse lines there as well. |
I didn't so much notice faster convergence as I did better likeness, but that may have to do with the fact that we're fine-tuning, not training a model from scratch. In this implementation, I noticed less overtraining on superficial details like eyebags and facial imperfections, which can occur when the random timesteps are very low. Low timesteps/low noise force back propagation into both a big update step, and that step to be superficial. |
what are the recommended values from your testing? |
@TingTingin During my testing I added |
I would think that you will need to add the min-snr to all trainer before @kohya-ss can merge. In the current state it would only apply to one of the 4 trainer... |
Just sitting down to refactor it now. |
Alright, I refactored into a separate file per @bmaltais suggestion, and added the function to all 4 trainers. I do not have a dataset on hand to test fine_tune.py, but the other 3 function as intended. |
After reading the paper, I thought that SNR calculation code by author. The |
@laksjdjf thanks for linking me the official code, it wasn't out at the time I coded this up. I'm not math guy by any stretch, but the code sanity checks, and has been checked over by better math brains than mine. If there are improvements, I'm completely game. As to their code, I see I was just asking said math brain today how we could determine the SNR timestep correlation, but we hadn't worked it out. If the ratio of mean squares is not accurate enough, we can change systems. If you add a |
You know what? I think you're right, let me do some tests to see where the ratios end up. |
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise and the variance of noise and original_samples is 1, so the SNR of noisy latents becomes
I did test to see snr ratio by below code. import torch
from diffusers import DDPMScheduler
import matplotlib.pyplot as plt
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
timesteps = torch.arange(0,1000)
def get_snr(
scheduler,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
sqrt_alpha_prod = scheduler.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = (1 - scheduler.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
return (sqrt_alpha_prod / sqrt_one_minus_alpha_prod) ** 2
snr = get_snr(scheduler, timesteps)
plt.xlabel("timesteps")
plt.ylabel("snr")
plt.plot(timesteps,snr) If it is limited to between 100 and 1000 because it is not clear, I am not sure if these numbers are reasonable. |
Pardon my git illiteracy there. Implemented the necessary changes to align with the authors calculations
Your graphs are correct @laksjdjf and here is a google sheet with a graph of snr_weight scalable by gamma Thank you for bringing this to my attention! |
Trying to git push at 1 am is too much for my brain I guess. Thanks to the kind soul from discord who caught my mistake. |
Thank you @AI-Casanova and everyone for the great discussion! I don't fully understand the background of the theory, but the results are excellent! I think this PR is ready to merge. I will be merging this today. If you have any concerns, please let me know. |
Great work implementing this! Very excited to do some comparisons with this modified loss calculation. I did run into a small issue due while testing this repo due to datatype incompatibility with numpy (specifically bf16):
It should be an easy fix by just changing the sqrt functions to their torch equivalents. It also may be slightly cleaner notation to change import torch
import argparse
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
all_snr.to(loss.device)
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float().to(loss.device) #from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper.") Thanks again for finding this paper and sharing with everyone! |
Thanks for testing, @mgz-dev, and finding the TypeError with bf16. Switching to Ran a quick test with fp16 to ensure that the loss graphs were a match. Let me know if you find anything else. |
Current implementation is not applicable to SD2 (v-pred), it only considered SD1 (eps-pred). Here's an improved implementation I thought: def min_snr_weight(scheduler, t, gamma):
alpha_cp = scheduler.alphas_cumprod
sigma_pow_2 = 1.0 - alpha_cp
snr = (alpha_cp ** 2) / sigma_pow_2
snr_t = snr[t]
match scheduler.config.prediction_type:
case "epsilon":
min_snr_w = torch.minimum(gamma / snr_t, torch.ones_like(t, dtype=torch.float32))
case "sample":
min_snr_w = torch.minimum(snr_t, torch.full_like(t, gamma, dtype=torch.float32))
case "v_prediction":
min_snr_w = torch.minimum(snr_t + 1, torch.full_like(t, gamma, dtype=torch.float32))
case _:
raise Exception("Unknown prediction type")
return min_snr_w |
@CCRcmcpe Gimme a bit to think through this, I did forget to disclaim V2 incompatibility. |
Also @CCRcmcpe wouldn't the velocity factor be `min(SNR,gamma)/(SNR+1) |
Yeah I'm mistaken. |
Has anything been done regarding the v_prediction model? |
Remove legacy 8bit adam checkbox
Implementation of https://arxiv.org/abs/2303.09556
Low noise timesteps produce outsized loss (as I discovered on my own here #294), which can lead to training instability as single samples make large steps in a direction that may not be advantageous.
This paper introduces a scaling factor gamma, accessible with the new argument
--min_snr_gamma
that lowers the weight of these low timesteps by calculating the signal to noise ratio.From the highest loss to lowest gamma=20,5,4,3,2,1
(Generated from the losses above)