Skip to content

Commit

Permalink
Add CFG++ Euler Dy/SMEA Dy samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
pamparamm committed Jul 19, 2024
1 parent 62eb791 commit 490ff66
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 41 deletions.
38 changes: 0 additions & 38 deletions js/attention_couple.js

This file was deleted.

38 changes: 38 additions & 0 deletions js/attention_couple_ppm.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Taken from https://github.com/laksjdjf/cgem156-ComfyUI/blob/1f5533f7f31345bafe4b833cbee15a3c4ad74167/js/attention_couple.js
import { app } from "/scripts/app.js";

app.registerExtension({
name: "AttentionCouplePPM",
async beforeRegisterNodeDef(nodeType, nodeData) {
if (nodeData.name === "AttentionCouplePPM") {
const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions;
nodeType.prototype.getExtraMenuOptions = function (_, options) {
const r = origGetExtraMenuOptions?.apply?.(this, arguments);
options.unshift(
{
content: "Add Region",
callback: () => {
var index = 1;
if (this.inputs != undefined) {
index += this.inputs.length;
}
this.addInput("cond_" + Math.floor(index / 2), "CONDITIONING");
this.addInput("mask_" + Math.floor(index / 2), "MASK");
},
},
{
content: "Remove Region",
callback: () => {
if (this.inputs != undefined && this.inputs.at(-2)["type"] === "CONDITIONING") {
this.removeInput(this.inputs.length - 1);
this.removeInput(this.inputs.length - 1);
}
},
},
);
return r;

}
}
},
});
170 changes: 170 additions & 0 deletions ppm_cfgpp_dyn_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Modified samplers from Euler-Smea-Dyn-Sampler by Koishi-Star
from tqdm.auto import trange
import torch

from comfy.k_diffusion import sampling
import comfy.model_patcher


class _Rescaler:
def __init__(self, model, x, mode, **extra_args):
self.model = model
self.x = x
self.mode = mode
self.extra_args = extra_args

self.latent_image, self.noise = model.latent_image, model.noise
self.denoise_mask = self.extra_args.get("denoise_mask", None)

def __enter__(self):
if self.latent_image is not None:
self.model.latent_image = torch.nn.functional.interpolate(input=self.latent_image, size=self.x.shape[2:4], mode=self.mode)
if self.noise is not None:
self.model.noise = torch.nn.functional.interpolate(input=self.latent_image, size=self.x.shape[2:4], mode=self.mode)
if self.denoise_mask is not None:
self.extra_args["denoise_mask"] = torch.nn.functional.interpolate(input=self.denoise_mask, size=self.x.shape[2:4], mode=self.mode)

return self

def __exit__(self, type, value, traceback):
del self.model.latent_image, self.model.noise
self.model.latent_image, self.model.noise = self.latent_image, self.noise


@torch.no_grad()
def dy_sampling_step_cfg_pp(x, model, sigma_next, sigma_hat, **extra_args):
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

original_shape = x.shape
batch_size, channels, m, n = original_shape[0], original_shape[1], original_shape[2] // 2, original_shape[3] // 2
extra_row = x.shape[2] % 2 == 1
extra_col = x.shape[3] % 2 == 1

if extra_row:
extra_row_content = x[:, :, -1:, :]
x = x[:, :, :-1, :]
if extra_col:
extra_col_content = x[:, :, :, -1:]
x = x[:, :, :, :-1]

a_list = x.unfold(2, 2, 2).unfold(3, 2, 2).contiguous().view(batch_size, channels, m * n, 2, 2)
c = a_list[:, :, :, 1, 1].view(batch_size, channels, m, n)

with _Rescaler(model, c, 'nearest-exact', **extra_args) as rescaler:
denoised = model(c, sigma_hat * c.new_ones([c.shape[0]]), **rescaler.extra_args)
d = sampling.to_d(c, sigma_hat, temp[0])
c = denoised + d * sigma_next

d_list = c.view(batch_size, channels, m * n, 1, 1)
a_list[:, :, :, 1, 1] = d_list[:, :, :, 0, 0]
x = a_list.view(batch_size, channels, m, n, 2, 2).permute(0, 1, 2, 4, 3, 5).reshape(batch_size, channels, 2 * m, 2 * n)

if extra_row or extra_col:
x_expanded = torch.zeros(original_shape, dtype=x.dtype, device=x.device)
x_expanded[:, :, :2 * m, :2 * n] = x
if extra_row:
x_expanded[:, :, -1:, :2 * n + 1] = extra_row_content
if extra_col:
x_expanded[:, :, :2 * m, -1:] = extra_col_content
if extra_row and extra_col:
x_expanded[:, :, -1:, -1:] = extra_col_content[:, :, -1:, :]
x = x_expanded

return x


@torch.no_grad()
def sample_euler_dy_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
s_tmax=float('inf'), s_noise=1., s_gamma=None):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])

temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

for i in trange(len(sigmas) - 1, disable=disable):
gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
if s_gamma is not None:
gamma = s_gamma
sigma_hat = sigmas[i] * (gamma + 1)
# print(sigma_hat)
dt = sigmas[i + 1] - sigma_hat
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = sampling.to_d(x, sigma_hat, temp[0])
# Euler method
x = denoised + d * sigmas[i + 1]
if sigmas[i + 1] > 0:
if i // 2 == 1:
x = dy_sampling_step_cfg_pp(x, model, sigmas[i + 1], sigma_hat, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
return x


@torch.no_grad()
def smea_sampling_step_cfg_pp(x, model, sigma_next, sigma_hat, **extra_args):
temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

m, n = x.shape[2], x.shape[3]
x = torch.nn.functional.interpolate(input=x, scale_factor=(1.25, 1.25), mode='nearest-exact')
with _Rescaler(model, x, 'nearest-exact', **extra_args) as rescaler:
denoised = model(x, sigma_hat * x.new_ones([x.shape[0]]), **rescaler.extra_args)
d = sampling.to_d(x, sigma_hat, temp[0])
x = denoised + d * sigma_next
x = torch.nn.functional.interpolate(input=x, size=(m,n), mode='nearest-exact')
return x


@torch.no_grad()
def sample_euler_smea_dy_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
s_tmax=float('inf'), s_noise=1.):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])

temp = [0]
def post_cfg_function(args):
temp[0] = args["uncond_denoised"]
return args["denoised"]

model_options = extra_args.get("model_options", {}).copy()
extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)

for i in trange(len(sigmas) - 1, disable=disable):
gamma = max(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
sigma_hat = sigmas[i] * (gamma + 1)
dt = sigmas[i + 1] - sigma_hat
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x - eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = sampling.to_d(x, sigma_hat, temp[0])
# Euler method
x = denoised + d * sigmas[i + 1]
if sigmas[i + 1] > 0:
if i + 1 // 2 == 1:
x = dy_sampling_step_cfg_pp(x, model, sigmas[i + 1], sigma_hat, **extra_args)
if i + 1 // 2 == 0:
x = smea_sampling_step_cfg_pp(x, model, sigmas[i + 1], sigma_hat, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
return x
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-ppm"
description = "Fixed AttentionCouple/NegPip(negative weights in prompts), more CFG++ samplers, etc."
version = "1.0.2"
version = "1.0.3"
license = "AGPL-3.0"

[project.urls]
Expand Down
8 changes: 6 additions & 2 deletions samplers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.samplers import KSAMPLER
from . import ppm_cfgpp_sampling
from . import ppm_cfgpp_dyn_sampling

INITIALIZED = False
CFGPP_SAMPLER_NAMES_ORIGINAL = ["euler_cfg_pp", "euler_ancestral_cfg_pp"]
CFGPP_SAMPLER_NAMES_DYN = ["euler_dy_cfg_pp", "euler_smea_dy_cfg_pp"]
CFGPP_SAMPLER_NAMES = CFGPP_SAMPLER_NAMES_ORIGINAL + [
"dpmpp_2m_cfg_pp",
"dpmpp_2m_sde_cfg_pp",
"dpmpp_2m_sde_gpu_cfg_pp",
"dpmpp_3m_sde_cfg_pp",
"dpmpp_3m_sde_gpu_cfg_pp",
]
] + CFGPP_SAMPLER_NAMES_DYN


def inject_samplers():
Expand Down Expand Up @@ -38,8 +40,10 @@ def INPUT_TYPES(s):
def get_sampler(self, sampler_name, eta: float):
if sampler_name in CFGPP_SAMPLER_NAMES_ORIGINAL:
sampler_func = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
elif sampler_name in CFGPP_SAMPLER_NAMES_DYN:
sampler_func = getattr(ppm_cfgpp_dyn_sampling, "sample_{}".format(sampler_name))
else:
sampler_func = getattr(ppm_cfgpp_sampling, "sample_{}".format(sampler_name))
extra_options = {} if sampler_name in {"euler_cfg_pp", "dpmpp_2m_cfg_pp"} else {"eta": eta}
extra_options = {} if sampler_name in {"euler_cfg_pp", "dpmpp_2m_cfg_pp", "euler_dy_cfg_pp", "euler_smea_dy_cfg_pp"} else {"eta": eta}
sampler = KSAMPLER(sampler_func, extra_options=extra_options)
return (sampler,)

0 comments on commit 490ff66

Please sign in to comment.