Skip to content

Commit

Permalink
Add a TomePatchModel node to the _for_testing section.
Browse files Browse the repository at this point in the history
Tome increases sampling speed at the expense of quality.
  • Loading branch information
comfyanonymous committed Mar 31, 2023
1 parent 7e68278 commit 18a6c1d
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 11 deletions.
15 changes: 13 additions & 2 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import model_management

from . import tomesd

if model_management.xformers_enabled():
import xformers
Expand Down Expand Up @@ -508,8 +509,18 @@ def forward(self, x, context=None, transformer_options={}):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)

def _forward(self, x, context=None, transformer_options={}):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
n = self.norm1(x)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
else:
n = self.attn1(n, context=context if self.disable_self_attn else None)

x += n
n = self.norm2(x)
n = self.attn2(n, context=context)

x += n
x = self.ff(self.norm3(x)) + x
return x

Expand Down
117 changes: 117 additions & 0 deletions comfy/ldm/modules/tomesd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@


import torch
from typing import Tuple, Callable
import math

def do_nothing(x: torch.Tensor, mode:str=None):
return x


def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int, h: int, sx: int, sy: int, r: int,
no_rand: bool = False) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
Args:
- metric [B, N, C]: metric to use for similarity
- w: image width in tokens
- h: image height in tokens
- sx: stride in the x dimension for dst, must divide w
- sy: stride in the y dimension for dst, must divide h
- r: number of tokens to remove (by merging)
- no_rand: if true, disable randomness (use top left corner only)
"""
B, N, _ = metric.shape

if r <= 0:
return do_nothing, do_nothing

with torch.no_grad():

hsy, wsx = h // sy, w // sx

# For each sy by sx kernel, randomly assign one token to be dst and the rest src
idx_buffer = torch.zeros(1, hsy, wsx, sy*sx, 1, device=metric.device)

if no_rand:
rand_idx = torch.zeros(1, hsy, wsx, 1, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(1, hsy, wsx, 1, 1), device=metric.device)

idx_buffer.scatter_(dim=3, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=idx_buffer.dtype))
idx_buffer = idx_buffer.view(1, hsy, wsx, sy, sx, 1).transpose(2, 3).reshape(1, N, 1)
rand_idx = idx_buffer.argsort(dim=1)

num_dst = int((1 / (sx*sy)) * N)
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst

def split(x):
C = x.shape[-1]
src = x.gather(dim=1, index=a_idx.expand(B, N - num_dst, C))
dst = x.gather(dim=1, index=b_idx.expand(B, num_dst, C))
return src, dst

metric = metric / metric.norm(dim=-1, keepdim=True)
a, b = split(metric)
scores = a @ b.transpose(-1, -2)

# Can't reduce more than the # tokens in src
r = min(a.shape[1], r)

node_max, node_idx = scores.max(dim=-1)
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
src_idx = edge_idx[..., :r, :] # Merged Tokens
dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape

unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

return torch.cat([unm, dst], dim=1)

def unmerge(x: torch.Tensor) -> torch.Tensor:
unm_len = unm_idx.shape[1]
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
_, _, c = unm.shape

src = dst.gather(dim=-2, index=dst_idx.expand(B, r, c))

# Combine back to the original shape
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
out.scatter_(dim=-2, index=a_idx.expand(B, a_idx.shape[1], 1).gather(dim=1, index=src_idx).expand(B, r, c), src=src)

return out

return merge, unmerge


def get_functions(x, ratio, original_shape):
b, c, original_h, original_w = original_shape
original_tokens = original_h * original_w
downsample = int(math.sqrt(original_tokens // x.shape[1]))
stride_x = 2
stride_y = 2
max_downsample = 1

if downsample <= max_downsample:
w = original_w // downsample
h = original_h // downsample
r = int(x.shape[1] * ratio)
no_rand = True
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
return m, u

nothing = lambda y: y
return nothing, nothing
17 changes: 9 additions & 8 deletions comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def cond_cat(c_list):
out['c_concat'] = [torch.cat(c_concat)]
return out

def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0

Expand Down Expand Up @@ -195,7 +195,7 @@ def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_tot


max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
return uncond + (cond - uncond) * cond_scale


Expand All @@ -212,20 +212,20 @@ def __init__(self, model):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
return out


class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
if denoise_mask is not None:
out *= denoise_mask

Expand Down Expand Up @@ -333,7 +333,7 @@ class KSampler:
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
"dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]

def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v":
Expand All @@ -353,6 +353,7 @@ def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=N
self.sigma_max=float(self.model_wrap.sigma_max)
self.set_steps(steps, denoise)
self.denoise = denoise
self.model_options = model_options

def _calculate_sigmas(self, steps):
sigmas = None
Expand Down Expand Up @@ -421,7 +422,7 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N
else:
precision_scope = contextlib.nullcontext

extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

cond_concat = None
if hasattr(self.model, 'concat_keys'):
Expand Down
9 changes: 9 additions & 0 deletions comfy/sd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import contextlib
import copy

import sd1_clip
import sd2_clip
Expand Down Expand Up @@ -274,12 +275,20 @@ def __init__(self, model):
self.model = model
self.patches = []
self.backup = {}
self.model_options = {"transformer_options":{}}

def clone(self):
n = ModelPatcher(self.model)
n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
return n

def set_model_tomesd(self, ratio):
self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio}

def model_dtype(self):
return self.model.diffusion_model.dtype

def add_patches(self, patches, strength=1.0):
p = {}
model_sd = self.model.state_dict()
Expand Down
19 changes: 18 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,22 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
return (model_lora, clip_lora)

class TomePatchModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "_for_testing"

def patch(self, model, ratio):
m = model.clone()
m.set_model_tomesd(ratio)
return (m, )

class VAELoader:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -646,7 +662,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
model_management.load_controlnet_gpu(control_net_models)

if sampler_name in comfy.samplers.KSampler.SAMPLERS:
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
else:
#other samplers
pass
Expand Down Expand Up @@ -1016,6 +1032,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
"CLIPVisionLoader": CLIPVisionLoader,
"VAEDecodeTiled": VAEDecodeTiled,
"VAEEncodeTiled": VAEEncodeTiled,
"TomePatchModel": TomePatchModel,
}

def load_custom_node(module_path):
Expand Down

2 comments on commit 18a6c1d

@catboxanon
Copy link

@catboxanon catboxanon commented on 18a6c1d Apr 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to modify this implementation based on the author's comments here: AUTOMATIC1111/stable-diffusion-webui#9256 (comment)

@comfyanonymous
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly how I implemented it though.

Please sign in to comment.