Skip to content

Commit

Permalink
Update sd_schedulers.py
Browse files Browse the repository at this point in the history
Implement Karras Exponential Scheduler for Enhanced Sigma Blending in Diffusion Models

This commit introduces the Karras Exponential Scheduler function, which blends sigma values generated by the Karras and Exponential scheduling methods. The primary objective is to enhance the noise schedule during the diffusion process, improving image generation quality, convergence speed, and maintaining structural details, particularly for img2img tasks.

What is the Karras Exponential Scheduler?
The Karras Exponential Scheduler is a method that dynamically blends two distinct scheduling approaches—Karras and Exponential—to create a combined noise schedule for diffusion models.

Karras Scheduling: A method known for its ability to generate high-quality images by spacing noise levels in a way that emphasizes low-noise areas, leading to more detailed and coherent images.

Exponential Scheduling: Emphasizes a more gradual reduction of noise, providing a smoother transition that can help in areas where the model requires more flexibility to explore variations.
By blending these two methods, the scheduler creates a noise schedule that leverages the strengths of both approaches, leading to better overall performance during image generation.

Why Use the Karras Exponential Scheduler?
Improved Convergence: The blended approach allows the model to converge more quickly on high-quality results, reducing artifacts commonly seen in traditional img2img and initial image generation tasks.

Enhanced Structural Preservation: By combining Karras's focus on detail preservation with Exponential’s smoother transitions, this scheduler helps maintain critical features, such as faces, hands, and complex object structures, which are often lost in purely exponential or linear scheduling.

Dynamic Blending for Flexibility: The scheduler employs a dynamic blend factor that adjusts the influence of each method throughout the diffusion process, providing a more adaptive approach to image generation.

Sharpening Mechanism: An additional sharpening factor is applied to adjust the impact of low noise levels, fine-tuning the result and preventing over-smoothing of fine details.

Key Changes Introduced:
karras_exponential_scheduler Function: Blends Karras and Exponential sigma sequences based on a dynamic blend factor that adjusts during the diffusion process.
Sharpening Feature: Applies conditional sharpening to low-value sigmas to maintain image sharpness and prevent loss of detail.
Error Handling: Robust error handling with fallback options ensures stable performance even when unexpected input issues occur.

Benefits:
Better Quality: Results in cleaner, more coherent images with less distortion, particularly noticeable in img2img scenarios.
Increased Control: Provides additional tuning options via blend_factor and sharpen_factor, allowing for fine-grained adjustments based on the specific task.

Stability: Error handling ensures the model continues to function even when generation challenges arise, offering fallback sigmas to maintain consistency.

Usage:
This scheduler can be directly integrated into diffusion pipelines and used as an alternative to existing sigma scheduling methods, particularly where traditional schedules struggle with img2img or high-detail preservation.

This addition aims to improve the overall robustness and versatility of diffusion models, making them more effective in real-world applications requiring high detail retention and image consistency.
  • Loading branch information
unicornsyay authored Sep 21, 2024
1 parent 82a973c commit 02b0092
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions modules/sd_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,58 @@ def ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):
return torch.FloatTensor(sigs).to(device)


def karras_exponential_scheduler(n, sigma_min, sigma_max, device, blend_factor=0.3, sharpen_factor=0.9):
# Optional: Adjust sigma_max to fine-tune the range of noise levels
# Example: Increase by 10%; modify the multiplier as needed (e.g., 1.1 for 10% increase)
sigma_max = sigma_max * 1.1 # Adjust this multiplier as needed, e.g., 1.1 for a 10% increase
# Initialize sigmas to None to avoid UnboundLocalError in case of failure during assignment
sigmas_karras, sigmas_exponential = None, None
try:
# Generate sigma schedules using Karras and Exponential methods
# These functions are from the k_diffusion module and are crucial for generating the noise schedule
sigmas_karras = k_diffusion.sampling.get_sigmas_karras(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
sigmas_exponential = k_diffusion.sampling.get_sigmas_exponential(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)

# Print the lengths of the generated sequences for debugging purposes
#print(f"Length before resampling: Karras - {len(sigmas_karras)}, Exponential - {len(sigmas_exponential)}")

# Check if lengths are different; resample if necessary to match lengths
if len(sigmas_karras) != len(sigmas_exponential):
# Resample both sigmas to match the longer sequence length; ensures consistent blending
max_length = max(len(sigmas_karras), len(sigmas_exponential))
sigmas_karras = resample_sigmas(sigmas_karras, max_length, device)
sigmas_exponential = resample_sigmas(sigmas_exponential, max_length, device)

except Exception as e:
# Handle errors during sigma generation; assign fallback empty tensors if an error occurs
print(f"Error generating sigmas: {e}")
sigmas_karras = torch.zeros(n).to(device)
sigmas_exponential = torch.zeros(n).to(device)

# Ensure sigmas have been assigned correctly; raise an error if not
if sigmas_karras is None or sigmas_exponential is None:
raise ValueError("Failed to generate or assign sigmas correctly.")

# Create a linear tensor from 0 to 1 to represent progress over the length of sigmas
progress = torch.linspace(0, 1, len(sigmas_karras)).to(device)
# Calculate a dynamic blend factor that decreases from blend_factor to 0
dynamic_blend_factor = (1 - progress) * blend_factor
# Blend the Karras and Exponential sigmas based on the dynamic blend factor
sigs = (sigmas_karras * (1 - dynamic_blend_factor) + sigmas_exponential * dynamic_blend_factor)

# Trim the blended sigmas if they exceed the required number of steps
if len(sigs) > n:
sigs = sigs[:n]

# Apply sharpening to sigmas below a certain threshold to enhance sharpness
# Modify sharpen_factor to adjust sharpening intensity
sharpen_mask = torch.where(sigs < sigma_min * 1.5, sharpen_factor, 1.0).to(device)
sigs = sigs * sharpen_mask

# Return the final blended and adjusted sigmas on the specified device
return sigs.to(device)


def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
# From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024) """
alpha = shared.opts.beta_dist_alpha
Expand All @@ -140,6 +192,8 @@ def beta_scheduler(n, sigma_min, sigma_max, inner_model, device):
Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),
Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),
Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),
Scheduler('karras_exponential', 'Karras Exponential', karras_exponential_scheduler),

]

schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}

0 comments on commit 02b0092

Please sign in to comment.