diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index bd8cb18c28..8971b4e6e9 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -1,6 +1,7 @@ import torch from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule import math +import numpy as np class EPS: def calculate_input(self, sigma, noise): @@ -69,12 +70,17 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) self.set_sigmas(sigmas) + self.set_alphas_cumprod(alphas_cumprod.float()) def set_sigmas(self, sigmas): self.register_buffer('sigmas', sigmas.float()) self.register_buffer('log_sigmas', sigmas.log().float()) + def set_alphas_cumprod(self, alphas_cumprod): + self.register_buffer("alphas_cumprod", alphas_cumprod.float()) + @property def sigma_min(self): return self.sigmas[0] diff --git a/modules/patch_precision.py b/modules/patch_precision.py index 83569bdd15..22ffda0adf 100644 --- a/modules/patch_precision.py +++ b/modules/patch_precision.py @@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti self.linear_end = linear_end sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) self.set_sigmas(sigmas) + alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32) + self.set_alphas_cumprod(alphas_cumprod) return