Stabilize DPM++, especially for SDXL and SDE-DPM++ #5541

Merged 6 commits on Oct 30, 2023 · Changes from 4 commits
34 changes: 32 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -117,9 +117,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        lower_order_final (`bool`, defaults to `True`):
            Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
            stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
        euler_at_final (`bool`, defaults to `False`):
            Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
            richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
            steps, but sometimes may result in blurring.
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
            the sigmas are determined according to a sequence of noise levels {σi}.
        use_lu_lambdas (`bool`, *optional*, defaults to `False`):
            Whether to use the uniform-logSNR step sizes proposed by Lu's DPM-Solver in the noise schedule during
            the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
            `lambda(t)`.
        lambda_min_clipped (`float`, defaults to `-inf`):
            Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
            cosine (`squaredcos_cap_v2`) noise schedule.
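
As a quick illustration of the two new flags documented above, here is a minimal usage sketch (assuming an SDXL pipeline; the checkpoint ID is only an example):

from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Any SDXL checkpoint works the same way; this ID is just an example.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Rebuild the scheduler with the options added in this PR:
# - euler_at_final=True trades a little detail for a more stable final step
# - use_lu_lambdas=True spaces steps uniformly in logSNR (lambda) space
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    euler_at_final=True,
    use_lu_lambdas=True,
)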
@@ -154,7 +162,9 @@ def __init__(
        algorithm_type: str = "dpmsolver++",
        solver_type: str = "midpoint",
        lower_order_final: bool = True,
        euler_at_final: bool = False,
Collaborator commented:
1. Should we default euler_at_final to True? You mentioned there is a tradeoff in image details, but I think artifacts are much more undesirable.
2. I think we should deprecate lower_order_final now that we have euler_at_final. We can strongly recommend setting euler_at_final to True when using fewer than 15 steps.

@patrickvonplaten (Contributor) replied on Oct 30, 2023:

> Should we default euler_at_final to True? You mentioned there is a tradeoff in image details, but I think artifacts are much more undesirable.

I would leave euler_at_final as False, since the current default setting works great for SDv15. SDXL seems to be more of a special case here.

> I think we should deprecate lower_order_final now that we have euler_at_final. We can strongly recommend setting euler_at_final to True when using fewer than 15 steps.

I don't think we should deprecate it. For backwards-compatibility reasons we'll need to keep both anyway, and again I think we should be careful not to break a functioning, well-working scheduler setting for SDv15.

Essentially, what is done here is to give the user the possibility of having lower_order_final=True behavior even for models where we use more than 15 inference steps, but this should not come at the expense of breaking existing workflows. So I don't think we can default euler_at_final to True, or remove lower_order_final.

Instead of adding euler_at_final, we could add a parameter called enable_lower_order_below: int = 15 that we allow users to set to 1000 for SDXL. But I'm not sure this is actually easier to understand or cleaner. So OK to leave as is for me!
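
For concreteness, a hypothetical sketch of that alternative (the parameter name enable_lower_order_below and this logic are only the idea floated above, not part of the PR):

# Hypothetical alternative to euler_at_final as discussed above; NOT implemented.
def lower_order_final_with_threshold(step_index: int, num_steps: int,
                                     lower_order_final: bool,
                                     enable_lower_order_below: int = 15) -> bool:
    # An SDXL user would set enable_lower_order_below=1000 to force a
    # lower-order final step regardless of schedule length.
    return (
        step_index == num_steps - 1
        and lower_order_final
        and num_steps < enable_lower_order_below
    )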

        use_karras_sigmas: Optional[bool] = False,
        use_lu_lambdas: Optional[bool] = False,
        lambda_min_clipped: float = -float("inf"),
        variance_type: Optional[str] = None,
        timestep_spacing: str = "linspace",
@@ -258,6 +268,12 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        elif self.config.use_lu_lambdas:
            lambdas = np.flip(log_sigmas.copy())
            lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
            sigmas = np.exp(lambdas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -354,6 +370,19 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Lu et al. (2022)."""

        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()

        rho = 1.0  # 1.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = lambda_min ** (1 / rho)
        max_inv_rho = lambda_max ** (1 / rho)
        lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return lambdas

    def convert_model_output(
        self,
        model_output: torch.FloatTensor,
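
Note that with rho = 1.0 the Karras-style formula in _convert_to_lu reduces to plain linear interpolation in lambda, i.e. steps uniformly spaced in logSNR. A standalone sketch (NumPy only; the lambda endpoints are made-up values):

import numpy as np

def convert_to_lu(lambda_max: float, lambda_min: float, num_steps: int, rho: float = 1.0):
    # Same arithmetic as _convert_to_lu above, but self-contained.
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = lambda_min ** (1 / rho)
    max_inv_rho = lambda_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# With rho=1.0 this is exactly np.linspace(lambda_max, lambda_min, n):
lambdas = convert_to_lu(lambda_max=2.0, lambda_min=-6.0, num_steps=5)
assert np.allclose(lambdas, np.linspace(2.0, -6.0, 5))
print(np.exp(lambdas))  # the corresponding sigmas, as used in set_timesteps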
@@ -787,8 +816,9 @@ def step(
        if self.step_index is None:
            self._init_step_index(timestep)

-        lower_order_final = (
-            (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+        # Improve numerical stability for small number of steps
+        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
@@ -804,8 +804,9 @@ def step(
        if self.step_index is None:
            self._init_step_index(timestep)

-        lower_order_final = (
-            (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
+        # Improve numerical stability for small number of steps
+        lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
+            self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
        )
        lower_order_second = (
            (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
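
Both step() hunks apply the same rule. As a quick sanity check of the new boolean, a standalone sketch that mirrors the condition in the diff (not library code):

def final_step_is_lower_order(step_index: int, num_steps: int,
                              euler_at_final: bool, lower_order_final: bool) -> bool:
    # Mirrors the condition introduced in the diff above.
    return (step_index == num_steps - 1) and (
        euler_at_final or (lower_order_final and num_steps < 15)
    )

# Old behavior is preserved: lower-order final step only for short schedules.
assert final_step_is_lower_order(9, 10, euler_at_final=False, lower_order_final=True)
assert not final_step_is_lower_order(24, 25, euler_at_final=False, lower_order_final=True)

# New: euler_at_final=True forces it for any schedule length (e.g. SDXL at 25 steps).
assert final_step_is_lower_order(24, 25, euler_at_final=True, lower_order_final=True)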
11 changes: 11 additions & 0 deletions tests/schedulers/test_scheduler_dpm_multi.py
@@ -29,6 +29,7 @@ def get_scheduler_config(self, **kwargs):
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lower_order_final": False,
"euler_at_final": False,
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}
@@ -195,6 +196,10 @@ def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_euler_at_final(self):
        self.check_over_configs(euler_at_final=True)
        self.check_over_configs(euler_at_final=False)

    def test_lambda_min_clipped(self):
        self.check_over_configs(lambda_min_clipped=-float("inf"))
        self.check_over_configs(lambda_min_clipped=-5.1)
@@ -258,6 +263,12 @@ def test_full_loop_with_karras_and_v_prediction(self):

        assert abs(result_mean.item() - 0.2096) < 1e-3

    def test_full_loop_with_lu_and_v_prediction(self):
        sample = self.full_loop(prediction_type="v_prediction", use_lu_lambdas=True)
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.1554) < 1e-3

    def test_switch(self):
        # make sure that iterating over schedulers with same config names gives same results
        # for defaults