Skip to content

Commit

Permalink
Add CFG++ samplers and GuidanceLimiter
Browse files Browse the repository at this point in the history
  • Loading branch information
pamparamm committed Jul 10, 2024
1 parent 3336b62 commit b49dafb
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 8 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Publish to Comfy registry
on:
workflow_dispatch:
push:
branches:
- master
paths:
- "pyproject.toml"

jobs:
publish-node:
name: Publish Custom Node to registry
runs-on: ubuntu-latest
# if this is a forked repository. Skipping the workflow.
if: github.event.repository.fork == false
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Publish Custom Node
uses: Comfy-Org/publish-node-action@main
with:
## Add your own personal access token to your Github Repository secrets and reference it here.
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ Modified implementation of AttentionCouple by [laksjdjf](https://github.com/laks

You can use multiple `LatentToMaskBB` nodes to set bounding box masks for `AttentionCouplePPM`. The parameters are relative to your initial latent: `x=0.5, y=0.0, w=0.5, h=1.0` will produce a mask covering right half of your image.

## CFG++SamplerSelect
Samplers adapted to [CFG++: Manifold-constrained Classifier Free Guidance for Diffusion Models by Chung et al.](https://cfgpp-diffusion.github.io/)

Should greatly reduce overexposure effect. Use together with `SamplerCustom` node. Don't forget to set CFG scale to 1.0-2.0 and PAG scale (if used) to 0.5-1.0.

## Guidance Limiter
Implementation of [Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models by Kynkäänniemi et al.](https://arxiv.org/abs/2404.07724) (also contains `RescaleCFG` functionality)

## Empty Latent Image (Aspect Ratio)
`Empty Latent Image (Aspect Ratio)` node generates empty latent with specified aspect ratio and with respect to target resolution.

Expand All @@ -29,4 +37,6 @@ Counts tokens in your prompt and returns them as a string. You can also print to
# Hooks/Hijacks

## Schedulers
Adds [AlignYourSteps scheduler modified by Extraltodeus](https://github.com/Extraltodeus/sigmas_tools_and_the_golden_scheduler/blob/0dc89a264ef346a093d053c0da751f3ece317613/sigmas_merge.py#L203-L233) to the default list of schedulers by replacing `comfy.samplers.calculate_sigmas` function. `ays` is the default AYS scheduler and `ays+` is just `ays` with `force_sigma_min=True`.
Adds [AlignYourSteps scheduler modified by Extraltodeus](https://github.com/Extraltodeus/sigmas_tools_and_the_golden_scheduler/blob/0dc89a264ef346a093d053c0da751f3ece317613/sigmas_merge.py#L203-L233) to the default list of schedulers by replacing `comfy.samplers.calculate_sigmas` function. `ays` is the default AYS scheduler and `ays+` is just `ays` with `force_sigma_min=True`.

Also adds GITS scheduler and AYS_30 scheduler (based on [AYS_32 by Koitenshin](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15751#issuecomment-2143648234))
7 changes: 7 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .clip_misc import CLIPTextEncodeBREAK, CLIPMicroConditioning, CLIPTokenCounter
from .clip_negpip import CLIPNegPip
from .attention_couple_ppm import AttentionCouplePPM
from .guidance_limiter import GuidanceLimiter
from .samplers import CFGPPSamplerSelect, inject_samplers
from .schedulers import hijack_schedulers

WEB_DIRECTORY = "./js"
Expand All @@ -20,6 +22,8 @@
"CLIPTokenCounter": CLIPTokenCounter,
"CLIPNegPip": CLIPNegPip,
"AttentionCouplePPM": AttentionCouplePPM,
"Guidance Limiter": GuidanceLimiter,
"CFGPPSamplerSelect": CFGPPSamplerSelect,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand All @@ -35,6 +39,9 @@
"CLIPTokenCounter": "CLIPTokenCounter",
"CLIPNegPip": "CLIPNegPip",
"AttentionCouplePPM": "AttentionCouplePPM",
"Guidance Limiter": "Guidance Limiter",
"CFGPPSamplerSelect": "CFG++SamplerSelect",
}

inject_samplers()
hijack_schedulers()
62 changes: 62 additions & 0 deletions guidance_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch


# Based on Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models by Kynkäänniemi et al.
class GuidanceLimiter:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"sigma_start": ("FLOAT", {"default": 5.42, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
"sigma_end": ("FLOAT", {"default": 0.28, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
"cfg_rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "model_patches"

def patch(self, model, sigma_start: float, sigma_end: float, cfg_rescale: float):
m = model.clone()

def limited_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
sigma = args["sigma"]
cond_scale = args["cond_scale"]

if sigma_start >= 0 and sigma[0] > sigma_start:
cond_scale = 1

if sigma_end >= 0 and sigma[0] <= sigma_end:
cond_scale = 1

if cfg_rescale > 0:
x_orig = args["input"]
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))

# rescale cfg has to be done on v-pred model output
x = x_orig / (sigma * sigma + 1.0)
cond = ((x - (x_orig - cond)) * (sigma**2 + 1.0) ** 0.5) / (sigma)
uncond = ((x - (x_orig - uncond)) * (sigma**2 + 1.0) ** 0.5) / (sigma)

# rescalecfg
x_cfg = uncond + cond_scale * (cond - uncond)
ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)

x_rescaled = x_cfg * (ro_pos / ro_cfg)
x_final = cfg_rescale * x_rescaled + (1.0 - cfg_rescale) * x_cfg

return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)

cfg_result = uncond + (cond - uncond) * cond_scale

return cfg_result

m.set_model_sampler_cfg_function(limited_cfg)

return (m,)
106 changes: 106 additions & 0 deletions ppm_cfgpp_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import torch
from tqdm import trange

import comfy.model_patcher
from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler


@torch.no_grad()
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None

temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_d
old_denoised = denoised
return x


@torch.no_grad()
def sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
"""DPM-Solver++(2M) SDE."""
if len(sigmas) <= 1:
return x

if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])

old_denoised = None
h_last = None
h = None

temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Denoising step
x = denoised
else:
# DPM-Solver++(2M) SDE
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
h = s - t
eta_h = eta * h

x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * (x + (denoised - temp[0])) + (-h - eta_h).expm1().neg() * denoised

if old_denoised is not None:
r = h_last / h
if solver_type == 'heun':
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
elif solver_type == 'midpoint':
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

if eta:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

old_denoised = denoised
h_last = h
return x


@torch.no_grad()
def sample_dpmpp_2m_sde_gpu_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
return x

sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[project]
name = "comfyui-ppm"
description = "Fixed AttentionCouple/NegPip(negative weights in prompts), more CFG++ samplers, etc."
version = "1.0.0"
license = "AGPL-3.0"

[project.urls]
Repository = "https://github.com/pamparamm/ComfyUI-ppm"
# Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "pamparamm"
DisplayName = "ComfyUI-ppm"
Icon = ""
37 changes: 37 additions & 0 deletions samplers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.samplers import KSAMPLER
from . import ppm_cfgpp_sampling

INITIALIZED = False
CFGPP_SAMPLER_NAMES_ORIGINAL = ["euler_cfg_pp", "euler_ancestral_cfg_pp"]
CFGPP_SAMPLER_NAMES = CFGPP_SAMPLER_NAMES_ORIGINAL + ["dpmpp_2m_cfg_pp", "dpmpp_2m_sde_cfg_pp", "dpmpp_2m_sde_gpu_cfg_pp"]


def inject_samplers():
global INITIALIZED
if not INITIALIZED:
INITIALIZED = True


# More CFG++ samplers based on https://github.com/comfyanonymous/ComfyUI/pull/3871 by yoinked-h
class CFGPPSamplerSelect:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"sampler_name": (CFGPP_SAMPLER_NAMES,),
}
}

RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"

FUNCTION = "get_sampler"

def get_sampler(self, sampler_name):
if sampler_name in CFGPP_SAMPLER_NAMES_ORIGINAL:
sampler_func = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
else:
sampler_func = getattr(ppm_cfgpp_sampling, "sample_{}".format(sampler_name))
sampler = KSAMPLER(sampler_func)
return (sampler,)
33 changes: 26 additions & 7 deletions schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,29 @@
import comfy.samplers
from comfy.samplers import SCHEDULER_NAMES, simple_scheduler
from comfy_extras.nodes_align_your_steps import loglinear_interp
from comfy_extras.nodes_gits import GITSScheduler


calculate_sigmas_original = comfy.samplers.calculate_sigmas

AYS_SCHEDULER = "ays"
AYS_PLUS_SCHEDULER = "ays+"
AYS_30_SCHEDULER = "ays_30"
AYS_30_PLUS_SCHEDULER = "ays_30+"
GITS_SCHEDULER = "gits"


# Modified AYS by Extraltodeus from https://github.com/Extraltodeus/sigmas_tools_and_the_golden_scheduler/blob/0dc89a264ef346a093d053c0da751f3ece317613/sigmas_merge.py#L203-L233
# 30-step AYS by Koitenshin from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15751#issuecomment-2143648234
def _ays_scheduler(model_sampling, steps, force_sigma_min, model_type="SDXL"):
timestep_indices = {
"SD1": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24, 0],
"SDXL": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13, 0],
"SVD": [995, 920, 811, 686, 555, 418, 315, 174, 109, 12, 0],
"SD1_30": [999, 954, 907, 855, 819, 782, 744, 713, 684, 654, 623, 592, 559, 528, 499, 470, 437, 401, 364, 328, 291, 256, 221, 183, 149, 114, 70, 41, 21, 4, 0],
"SDXL_30": [999, 953, 904, 850, 813, 777, 738, 695, 650, 602, 556, 510, 462, 417, 374, 331, 290, 250, 214, 182, 155, 131, 108, 85, 66, 49, 32, 20, 12, 3, 0],
}
indices = timestep_indices[model_type]
indices = [999 - i for i in indices]
indices.reverse()
sigmas = simple_scheduler(model_sampling, 1000)[indices]
sigmas = loglinear_interp(sigmas.tolist(), steps + 1 if not force_sigma_min else steps)
sigmas = torch.FloatTensor(sigmas)
Expand All @@ -25,17 +34,27 @@ def _ays_scheduler(model_sampling, steps, force_sigma_min, model_type="SDXL"):


def _calculate_sigmas_hijack(model_sampling, scheduler_name, steps):
if scheduler_name == "ays":
if scheduler_name == AYS_SCHEDULER:
sigmas = _ays_scheduler(model_sampling, steps, False)
elif scheduler_name == "ays+":
elif scheduler_name == AYS_PLUS_SCHEDULER:
sigmas = _ays_scheduler(model_sampling, steps, True)
elif scheduler_name == AYS_30_SCHEDULER:
sigmas = _ays_scheduler(model_sampling, steps, False, "SDXL_30")
elif scheduler_name == AYS_30_PLUS_SCHEDULER:
sigmas = _ays_scheduler(model_sampling, steps, True, "SDXL_30")
elif scheduler_name == GITS_SCHEDULER:
sigmas = GITSScheduler().get_sigmas(1.2, steps, 1.0)[0]
else:
sigmas = calculate_sigmas_original(model_sampling, scheduler_name, steps)
return sigmas


def hijack_schedulers():
SCHEDULER_NAMES.append("ays")
SCHEDULER_NAMES.append("ays+")
assert calculate_sigmas_original != _calculate_sigmas_hijack
if calculate_sigmas_original == _calculate_sigmas_hijack:
raise ValueError("Schedulers are already hijacked")
SCHEDULER_NAMES.append(AYS_SCHEDULER)
SCHEDULER_NAMES.append(AYS_PLUS_SCHEDULER)
SCHEDULER_NAMES.append(AYS_30_SCHEDULER)
SCHEDULER_NAMES.append(AYS_30_PLUS_SCHEDULER)
SCHEDULER_NAMES.append(GITS_SCHEDULER)
comfy.samplers.calculate_sigmas = _calculate_sigmas_hijack

0 comments on commit b49dafb

Please sign in to comment.