diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 6b1a43630fa6..b9183e6d80c8 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -117,9 +117,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
+        euler_at_final (`bool`, defaults to `False`):
+            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
+            richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
+            steps, but sometimes may result in blurring.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
             the sigmas are determined according to a sequence of noise levels {σi}.
+        use_lu_lambdas (`bool`, *optional*, defaults to `False`):
+            Whether to use uniform log-SNR spacing for step sizes, as proposed in Lu's DPM-Solver, in the noise
+            schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a
+            sequence of `lambda(t)`.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -154,7 +162,9 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
+        use_lu_lambdas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
@@ -258,6 +268,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
             sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
+        elif self.config.use_lu_lambdas:
+            lambdas = np.flip(log_sigmas.copy())
+            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
+            sigmas = np.exp(lambdas)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -354,6 +370,19 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
+    def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Lu et al. (2022)."""
(2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + def convert_model_output( self, model_output: torch.FloatTensor, @@ -787,8 +816,9 @@ def step( if self.step_index is None: self._init_step_index(timestep) - lower_order_final = ( - (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15) ) lower_order_second = ( (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index fa8f362bd3b5..0168b8405391 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -117,6 +117,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, defaults to `True`): Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. use_karras_sigmas (`bool`, *optional*, defaults to `False`): Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, the sigmas are determined according to a sequence of noise levels {σi}. 
@@ -154,6 +158,7 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
@@ -804,8 +809,9 @@ def step(
         if self.step_index is None:
             self._init_step_index(timestep)
 
-        lower_order_final = (
-            (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+        # Improve numerical stability for a small number of steps
+        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
         )
         lower_order_second = (
             (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 6e6442e0daf6..7fe71941b4e7 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -29,6 +29,7 @@ def get_scheduler_config(self, **kwargs):
             "algorithm_type": "dpmsolver++",
             "solver_type": "midpoint",
             "lower_order_final": False,
+            "euler_at_final": False,
             "lambda_min_clipped": -float("inf"),
             "variance_type": None,
         }
@@ -195,6 +196,10 @@ def test_lower_order_final(self):
         self.check_over_configs(lower_order_final=True)
         self.check_over_configs(lower_order_final=False)
 
+    def test_euler_at_final(self):
+        self.check_over_configs(euler_at_final=True)
+        self.check_over_configs(euler_at_final=False)
+
     def test_lambda_min_clipped(self):
         self.check_over_configs(lambda_min_clipped=-float("inf"))
         self.check_over_configs(lambda_min_clipped=-5.1)
@@ -258,6 +263,12 @@ def test_full_loop_with_karras_and_v_prediction(self):
 
         assert abs(result_mean.item() - 0.2096) < 1e-3
 
+    def test_full_loop_with_lu_and_v_prediction(self):
+        sample = self.full_loop(prediction_type="v_prediction", use_lu_lambdas=True)
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_mean.item() - 0.1554) < 1e-3
+
     def test_switch(self):
         # make sure that iterating over schedulers with same config names gives same results
         # for defaults
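For reviewers who want to try the new options end to end, here is a minimal usage sketch, not part of the diff itself: the checkpoint name is only a placeholder, and the snippet simply assumes a pipeline whose scheduler can be rebuilt via the existing `from_config` override mechanism.

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Placeholder checkpoint; any pipeline built around DPMSolverMultistepScheduler works.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Rebuild the scheduler from the pipeline's config with the two new flags:
# - use_lu_lambdas: timesteps spaced uniformly in lambda(t) (log-SNR), per Lu et al. (2022)
# - euler_at_final: first-order (Euler) update on the final step for extra stability
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_lu_lambdas=True, euler_at_final=True
)

# Both flags mainly matter at low step counts.
image = pipe("an astronaut riding a horse", num_inference_steps=10).images[0]
```

Note that in `set_timesteps` the Karras branch is checked first, so `use_karras_sigmas=True` takes precedence over `use_lu_lambdas=True`; the two spacings should not be enabled together.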