Skip to content

Commit

Permalink
Stabilize DPM++, especially for SDXL and SDE-DPM++ (huggingface#5541)
Browse files Browse the repository at this point in the history
* stabilize dpmpp for sdxl by using euler at the final step

* add lu's uniform logsnr time steps

* add test

* fix check_copies

* fix tests

---------

Co-authored-by: Patrick von Platen <[email protected]>
  • Loading branch information
2 people authored and kashif committed Nov 11, 2023
1 parent 7cdeb79 commit 354536b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 4 deletions.
34 changes: 32 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
steps, but may sometimes result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR step sizes proposed in Lu's DPM-Solver for the noise schedule during the
sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
`lambda(t)`.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
Expand Down Expand Up @@ -154,7 +162,9 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
Expand Down Expand Up @@ -258,6 +268,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_lu_lambdas:
lambdas = np.flip(log_sigmas.copy())
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
Expand Down Expand Up @@ -354,6 +370,19 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas

def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
    """Constructs the noise schedule of Lu et al. (2022).

    Only the first and last entries of `in_lambdas` are read; they are the two endpoints of the schedule.

    Args:
        in_lambdas: Sequence whose element `0` is the starting value and whose element `-1` is the final value.
            Both elements must support `.item()`.
        num_inference_steps: Number of points to generate.

    Returns:
        An array of `num_inference_steps` values running from `in_lambdas[0]` to `in_lambdas[-1]`. The result is
        derived from `np.linspace`, so it is a NumPy array despite the tensor annotation.
    """
    start: float = in_lambdas[0].item()
    stop: float = in_lambdas[-1].item()

    # 1.0 is the value used in the paper; with this exponent both power transforms below are the identity,
    # so the interpolation is linear in lambda.
    rho = 1.0
    inv_rho = 1 / rho
    start_pow = start ** inv_rho
    stop_pow = stop ** inv_rho

    # Fractions in [0, 1] that position each step between the two endpoints.
    fractions = np.linspace(0, 1, num_inference_steps)
    return (start_pow + fractions * (stop_pow - start_pow)) ** rho

def convert_model_output(
self,
model_output: torch.FloatTensor,
Expand Down Expand Up @@ -787,8 +816,9 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)

lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
euler_at_final (`bool`, defaults to `False`):
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
steps, but may sometimes result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
Expand Down Expand Up @@ -154,6 +158,7 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
Expand Down Expand Up @@ -804,8 +809,9 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)

lower_order_final = (
(self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
# Improve numerical stability for small number of steps
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
Expand Down
11 changes: 11 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_scheduler_config(self, **kwargs):
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lower_order_final": False,
"euler_at_final": False,
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}
Expand Down Expand Up @@ -195,6 +196,10 @@ def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)

def test_euler_at_final(self):
    """Run `check_over_configs` with `euler_at_final` enabled, then disabled."""
    for flag in (True, False):
        self.check_over_configs(euler_at_final=flag)

def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1)
Expand Down Expand Up @@ -258,6 +263,12 @@ def test_full_loop_with_karras_and_v_prediction(self):

assert abs(result_mean.item() - 0.2096) < 1e-3

def test_full_loop_with_lu_and_v_prediction(self):
    """Check the `full_loop` result for v-prediction with `use_lu_lambdas=True` against a reference value.

    The mean absolute value of the returned sample must be within 1e-3 of 0.1554.
    """
    expected_mean = 0.1554
    output = self.full_loop(prediction_type="v_prediction", use_lu_lambdas=True)
    mean_abs = torch.abs(output).mean()

    assert abs(mean_abs.item() - expected_mean) < 1e-3

def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results
# for defaults
Expand Down

0 comments on commit 354536b

Please sign in to comment.