
I have reason to believe "scale v-loss like epsilon loss" and Min-SNR-Gamma are implemented wrong. #673

Closed
drhead opened this issue Jul 21, 2023 · 13 comments
Labels
help wanted Extra attention is needed

Comments

@drhead

drhead commented Jul 21, 2023

I've been training a model using Kohya's implementation of Min-SNR-Gamma and the more recent option for scaling v-prediction loss like epsilon loss. Importantly, I am also training on v-prediction with zero terminal SNR.

I first found that the v-loss rescaling prevented a zero terminal SNR model from learning to produce fully black images, even after about 5 million training samples, yet the model learned it immediately once I turned that setting off. Others nevertheless noticed that the rescaling improved quality in other areas, which suggested there was a correct way to keep the benefit while fixing the flaw.

Looking further into the paper, the Min-SNR-Gamma authors state that the formula should be modified for v-loss, but their wording may be somewhat unclear:
[image: excerpt from the Min-SNR-Gamma paper showing the reweighted objective written three ways: min(SNR(t), γ) · ‖x₀ − x̂₀‖² = (min(SNR(t), γ) / SNR(t)) · ‖ε − ε̂‖² = min(γ / SNR(t), 1) · ‖ε − ε̂‖²]

Kohya implements the simplified formula on the right-hand side.

I have implemented and tested this alternative function for min_snr_gamma, based on the middle formula -- it is identical except that the denominator is replaced with SNR(t) + 1. My implementation is in JAX, since that is what my current training script uses, but converting it to PyTorch should mostly be a matter of removing the expand_dims line and replacing jnp with torch:

    import jax.numpy as jnp

    def apply_snr_weight_alt(loss, timesteps, noise_scheduler, gamma):
        # look up the precomputed per-timestep SNR for each sample in the batch
        snr = jnp.stack([noise_scheduler.all_snr[t] for t in timesteps])
        min_snr_gamma = jnp.minimum(snr, gamma)
        snr_weight = jnp.divide(min_snr_gamma, snr + 1).astype(jnp.float32)
        snr_weight = jnp.expand_dims(snr_weight, axis=(1, 2, 3))  # likely unnecessary for pytorch
        loss = loss * snr_weight
        return loss
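
For context, this assumes the scheduler exposes a precomputed all_snr lookup table (the PyTorch snippets below index into the same table). A minimal sketch of how such a table could be built from a DDPM-style alphas_cumprod schedule -- the function name is illustrative, not the exact sd-scripts code:

    import torch

    def compute_all_snr(alphas_cumprod: torch.Tensor) -> torch.Tensor:
        # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); finite wherever
        # alpha_bar_t < 1, and exactly 0 at a zero-terminal-SNR timestep
        return alphas_cumprod / (1.0 - alphas_cumprod)

    # e.g.: all_snr = compute_all_snr(noise_scheduler.alphas_cumprod)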

This is, as far as I am aware, the correct min_snr_gamma function for v-loss. It has performed well in my tests and has improved output quality without compromising contrast range. It serves the same purpose as the "scale v-loss like epsilon loss" option and yields loss metrics in the same range as epsilon loss. This is how Min-SNR-Gamma should behave under v-prediction, and it should fully replace the "scale v-loss like epsilon loss" option.
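
For reference, here is why SNR(t) + 1 is the right denominator (a short derivation in standard DDPM notation; the paper implies it but this thread doesn't spell it out). With x_t = √ᾱ_t x₀ + √(1−ᾱ_t) ε and v = √ᾱ_t ε − √(1−ᾱ_t) x₀, one can solve for x₀ = √ᾱ_t x_t − √(1−ᾱ_t) v, which gives

$$
\lVert v - \hat{v} \rVert^2 = \frac{\lVert x_0 - \hat{x}_0 \rVert^2}{1 - \bar{\alpha}_t} = \bigl(\mathrm{SNR}(t) + 1\bigr)\,\lVert x_0 - \hat{x}_0 \rVert^2,
\qquad \mathrm{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}.
$$

So weighting the v-loss by min(SNR(t), γ) / (SNR(t) + 1) recovers the paper's intended weight min(SNR(t), γ) on the x₀-loss, exactly as weighting the ε-loss by min(SNR(t), γ) / SNR(t) does (since ‖ε − ε̂‖² = SNR(t) · ‖x₀ − x̂₀‖²).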

Others I worked on this problem with have tested this function and found that it improves performance compared to the current implementation of the aforementioned options. If you need a model to test it on, I can release one of my prototypes, an SD 1.5 model trained on v-loss and zero terminal SNR. I would imagine SD 2.1 768-v would work as well.

@feffy380
Contributor

There's some discussion about this in the original pull request for Min-SNR-gamma, and AI-Casanova (the PR's author) also thinks it should probably be min(SNR, gamma)/(SNR + 1) for v-pred.

@kohya-ss
Owner

Thank you very much for this!

I am not a math person and my understanding may be incorrect, but does this mean we can modify the code as follows?

def apply_snr_weight_noise_pred(loss, timesteps, noise_scheduler, gamma):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
    loss = loss * snr_weight
    return loss


def apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma):
    if not v_prediction:
        return apply_snr_weight_noise_pred(loss, timesteps, noise_scheduler, gamma)

    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))  # gamma is a float, so wrap it in a tensor for torch.minimum
    snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    loss = loss * snr_weight
    return loss


# we can remove this function
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
    scale = snr_t / (snr_t + 1)

    loss = loss * scale
    return loss

@drhead
Author

drhead commented Jul 23, 2023

That should work, but I think the cleanest way to implement it is to change the denominator based on the prediction type: for epsilon-prediction the denominator should be snr, and for v-prediction it should be snr + 1. The epsilon-prediction case is then equivalent to the original implementation.

def apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma):
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))  # gamma is a float, so wrap it in a tensor for torch.minimum
    if v_prediction:
        snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
    else:
        snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
    loss = loss * snr_weight
    return loss
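
For completeness, a hypothetical sketch of how this slots into a training step (variable names like model_pred and target are illustrative; sd-scripts reduces the per-pixel loss to a per-sample value before weighting):

    import torch.nn.functional as F

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean([1, 2, 3])  # per-sample mean over C, H, W
    loss = apply_snr_weight_alt(v_prediction, loss, timesteps, noise_scheduler, gamma=5.0)
    loss = loss.mean()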

@kohya-ss
Owner

Thank you for the clarification!

The formulas seem to say that if we apply the current apply_snr_weight (apply_snr_weight_noise_pred above) and scale_v_prediction_loss_like_noise_prediction at the same time, we should be fine for the v-prediction case. Is this correct?

Both options can be specified at the same time.

@drhead
Author

drhead commented Jul 24, 2023

No, scale_v_prediction_loss_like_noise_prediction doesn't behave the same as the v-prediction path of my apply_snr_weight_alt, in part because of the clipping at timestep 0, which is the likely cause of the interference with zero terminal SNR that I mentioned. apply_snr_weight_alt ensures that infinite SNR at timestep 0 doesn't cause problems.

edit: I'd also like to emphasize that the formula used in the v-prediction code path should simply be the correct implementation of Min-SNR-Gamma; there shouldn't be a separate, optional loss rescale function for v-prediction. Min-SNR-Gamma used with v-prediction should always behave like this. Compatibility is a possible concern, but at least in the testing I've seen so far, this implementation gives better results than the two existing loss rescales.
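
To make the behavior at the extremes concrete, here is a quick spot check of the weight min(SNR, gamma) / (SNR + 1) with gamma = 5 on hypothetical SNR values, including the infinite-SNR case:

    import torch

    snr = torch.tensor([0.5, 5.0, 100.0, float("inf")])
    gamma = 5.0
    weight = torch.minimum(snr, torch.full_like(snr, gamma)) / (snr + 1)
    print(weight)  # tensor([0.3333, 0.8333, 0.0495, 0.0000])

The weight falls off smoothly toward zero as SNR grows, so no explicit clipping is needed at low timesteps.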

@laksjdjf
Contributor

By the way, is the clipping necessary?
The SNR when the timestep is 0 is the SNR of x_1, not x_0, so it is not inf.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")

def get_snr(
    scheduler, 
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:

    sqrt_alpha_prod = scheduler.alphas_cumprod[timesteps] ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()

    sqrt_one_minus_alpha_prod = (1 - scheduler.alphas_cumprod[timesteps]) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()

    return (sqrt_alpha_prod / sqrt_one_minus_alpha_prod) ** 2

get_snr(scheduler, torch.tensor(0))

# tensor([1175.4406])

@dill-shower
Any updates?

@kohya-ss
Owner

As laksjdjf wrote, I believe it is OK to specify both the apply_snr_weight and scale_v_prediction_loss_like_noise_prediction options.

@feffy380
Contributor

Unfortunately not. Combining both leads to the loss being scaled twice: scale_v_prediction_loss_like_noise_prediction applies the v-pred version of the formula from the paper, but with a hardcoded gamma = 1000.
They really should be one function, like @drhead's example.

@drhead
Author

drhead commented Sep 18, 2023

I should reiterate that apply_snr_weight, as it currently exists, is an incorrect implementation of the paper when training with v-prediction. Using both it and scale_v_prediction_loss_like_noise_prediction is not mathematically equivalent to the implementation I have provided, and our testing has shown that the corrected version performs better.

@bghira
bghira commented Sep 18, 2023

I can confirm this after having discussed it with Tian, one of the original paper authors.

Additionally, I've implemented the fix in SimpleTuner as an unconditional fix for v_prediction-type models when min-SNR gamma is in use.

@O-J1
O-J1 commented Oct 15, 2023

It may be my lack of knowledge, but I had always wondered why enabling min-SNR on my dataset seemed to yield worse results in a way I couldn't really explain. Glad to know I wasn't imagining things. Hope this can be solved soon.

kohya-ss added the "help wanted" label on Oct 22, 2023
@drhead
Author

drhead commented Nov 27, 2023

Fixed with merge of #934
