-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CFG++ samplers and GuidanceLimiter
- Loading branch information
Showing
8 changed files
with
286 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name: Publish to Comfy registry | ||
on: | ||
workflow_dispatch: | ||
push: | ||
branches: | ||
- master | ||
paths: | ||
- "pyproject.toml" | ||
|
||
jobs: | ||
publish-node: | ||
name: Publish Custom Node to registry | ||
runs-on: ubuntu-latest | ||
# if this is a forked repository. Skipping the workflow. | ||
if: github.event.repository.fork == false | ||
steps: | ||
- name: Check out code | ||
uses: actions/checkout@v4 | ||
- name: Publish Custom Node | ||
uses: Comfy-Org/publish-node-action@main | ||
with: | ||
## Add your own personal access token to your Github Repository secrets and reference it here. | ||
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import torch | ||
|
||
|
||
# Based on Applying Guidance in a Limited Interval Improves Sample and Distribution Quality in Diffusion Models by Kynkäänniemi et al. | ||
class GuidanceLimiter: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return { | ||
"required": { | ||
"model": ("MODEL",), | ||
"sigma_start": ("FLOAT", {"default": 5.42, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}), | ||
"sigma_end": ("FLOAT", {"default": 0.28, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}), | ||
"cfg_rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), | ||
} | ||
} | ||
|
||
RETURN_TYPES = ("MODEL",) | ||
FUNCTION = "patch" | ||
|
||
CATEGORY = "model_patches" | ||
|
||
def patch(self, model, sigma_start: float, sigma_end: float, cfg_rescale: float): | ||
m = model.clone() | ||
|
||
def limited_cfg(args): | ||
cond = args["cond"] | ||
uncond = args["uncond"] | ||
sigma = args["sigma"] | ||
cond_scale = args["cond_scale"] | ||
|
||
if sigma_start >= 0 and sigma[0] > sigma_start: | ||
cond_scale = 1 | ||
|
||
if sigma_end >= 0 and sigma[0] <= sigma_end: | ||
cond_scale = 1 | ||
|
||
if cfg_rescale > 0: | ||
x_orig = args["input"] | ||
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) | ||
|
||
# rescale cfg has to be done on v-pred model output | ||
x = x_orig / (sigma * sigma + 1.0) | ||
cond = ((x - (x_orig - cond)) * (sigma**2 + 1.0) ** 0.5) / (sigma) | ||
uncond = ((x - (x_orig - uncond)) * (sigma**2 + 1.0) ** 0.5) / (sigma) | ||
|
||
# rescalecfg | ||
x_cfg = uncond + cond_scale * (cond - uncond) | ||
ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True) | ||
ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True) | ||
|
||
x_rescaled = x_cfg * (ro_pos / ro_cfg) | ||
x_final = cfg_rescale * x_rescaled + (1.0 - cfg_rescale) * x_cfg | ||
|
||
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5) | ||
|
||
cfg_result = uncond + (cond - uncond) * cond_scale | ||
|
||
return cfg_result | ||
|
||
m.set_model_sampler_cfg_function(limited_cfg) | ||
|
||
return (m,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
from tqdm import trange | ||
|
||
import comfy.model_patcher | ||
from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler | ||
|
||
|
||
@torch.no_grad() | ||
def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): | ||
"""DPM-Solver++(2M).""" | ||
extra_args = {} if extra_args is None else extra_args | ||
s_in = x.new_ones([x.shape[0]]) | ||
sigma_fn = lambda t: t.neg().exp() | ||
t_fn = lambda sigma: sigma.log().neg() | ||
old_denoised = None | ||
|
||
temp = [0] | ||
def post_cfg_function(args): | ||
temp[0] = args["uncond_denoised"] | ||
return args["denoised"] | ||
|
||
model_options = extra_args.get("model_options", {}).copy() | ||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) | ||
|
||
for i in trange(len(sigmas) - 1, disable=disable): | ||
denoised = model(x, sigmas[i] * s_in, **extra_args) | ||
if callback is not None: | ||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) | ||
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) | ||
h = t_next - t | ||
if old_denoised is None or sigmas[i + 1] == 0: | ||
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised | ||
else: | ||
h_last = t - t_fn(sigmas[i - 1]) | ||
r = h_last / h | ||
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised | ||
x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_d | ||
old_denoised = denoised | ||
return x | ||
|
||
|
||
@torch.no_grad() | ||
def sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): | ||
"""DPM-Solver++(2M) SDE.""" | ||
if len(sigmas) <= 1: | ||
return x | ||
|
||
if solver_type not in {'heun', 'midpoint'}: | ||
raise ValueError('solver_type must be \'heun\' or \'midpoint\'') | ||
|
||
seed = extra_args.get("seed", None) | ||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() | ||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler | ||
extra_args = {} if extra_args is None else extra_args | ||
s_in = x.new_ones([x.shape[0]]) | ||
|
||
old_denoised = None | ||
h_last = None | ||
h = None | ||
|
||
temp = [0] | ||
def post_cfg_function(args): | ||
temp[0] = args["uncond_denoised"] | ||
return args["denoised"] | ||
|
||
model_options = extra_args.get("model_options", {}).copy() | ||
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) | ||
|
||
for i in trange(len(sigmas) - 1, disable=disable): | ||
denoised = model(x, sigmas[i] * s_in, **extra_args) | ||
if callback is not None: | ||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) | ||
if sigmas[i + 1] == 0: | ||
# Denoising step | ||
x = denoised | ||
else: | ||
# DPM-Solver++(2M) SDE | ||
t, s = -sigmas[i].log(), -sigmas[i + 1].log() | ||
h = s - t | ||
eta_h = eta * h | ||
|
||
x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * (x + (denoised - temp[0])) + (-h - eta_h).expm1().neg() * denoised | ||
|
||
if old_denoised is not None: | ||
r = h_last / h | ||
if solver_type == 'heun': | ||
x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) | ||
elif solver_type == 'midpoint': | ||
x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) | ||
|
||
if eta: | ||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise | ||
|
||
old_denoised = denoised | ||
h_last = h | ||
return x | ||
|
||
|
||
@torch.no_grad() | ||
def sample_dpmpp_2m_sde_gpu_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): | ||
if len(sigmas) <= 1: | ||
return x | ||
|
||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() | ||
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler | ||
return sample_dpmpp_2m_sde_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[project] | ||
name = "comfyui-ppm" | ||
description = "Fixed AttentionCouple/NegPip(negative weights in prompts), more CFG++ samplers, etc." | ||
version = "1.0.0" | ||
license = "AGPL-3.0" | ||
|
||
[project.urls] | ||
Repository = "https://github.com/pamparamm/ComfyUI-ppm" | ||
# Used by Comfy Registry https://comfyregistry.org | ||
|
||
[tool.comfy] | ||
PublisherId = "pamparamm" | ||
DisplayName = "ComfyUI-ppm" | ||
Icon = "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from comfy.k_diffusion import sampling as k_diffusion_sampling | ||
from comfy.samplers import KSAMPLER | ||
from . import ppm_cfgpp_sampling | ||
|
||
INITIALIZED = False | ||
CFGPP_SAMPLER_NAMES_ORIGINAL = ["euler_cfg_pp", "euler_ancestral_cfg_pp"] | ||
CFGPP_SAMPLER_NAMES = CFGPP_SAMPLER_NAMES_ORIGINAL + ["dpmpp_2m_cfg_pp", "dpmpp_2m_sde_cfg_pp", "dpmpp_2m_sde_gpu_cfg_pp"] | ||
|
||
|
||
def inject_samplers(): | ||
global INITIALIZED | ||
if not INITIALIZED: | ||
INITIALIZED = True | ||
|
||
|
||
# More CFG++ samplers based on https://github.com/comfyanonymous/ComfyUI/pull/3871 by yoinked-h | ||
class CFGPPSamplerSelect: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return { | ||
"required": { | ||
"sampler_name": (CFGPP_SAMPLER_NAMES,), | ||
} | ||
} | ||
|
||
RETURN_TYPES = ("SAMPLER",) | ||
CATEGORY = "sampling/custom_sampling/samplers" | ||
|
||
FUNCTION = "get_sampler" | ||
|
||
def get_sampler(self, sampler_name): | ||
if sampler_name in CFGPP_SAMPLER_NAMES_ORIGINAL: | ||
sampler_func = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) | ||
else: | ||
sampler_func = getattr(ppm_cfgpp_sampling, "sample_{}".format(sampler_name)) | ||
sampler = KSAMPLER(sampler_func) | ||
return (sampler,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters