Implement Deepcache Optimization #14210

Draft · wants to merge 5 commits into base: dev
180 changes: 180 additions & 0 deletions extensions-builtin/deepcache/deepcache.py
@@ -0,0 +1,180 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
from collections import defaultdict
from logging import getLogger

import torch
from ldm.modules.diffusionmodules.openaimodel import timestep_embedding
from scripts.forward_timestep_embed_patch import forward_timestep_embed


@dataclass
class DeepCacheParams:
    cache_in_level: int = 0
    cache_enable_step: int = 0
    full_run_step_rate: int = 5
    # cache_latents_cpu: bool = False
    # cache_latents_hires: bool = False

class DeepCacheSession:
    """
    Session for DeepCache, which holds cache data and provides functions for hooking the model.
    """
    def __init__(self) -> None:
        self.CACHE_LAST = {"timestep": {0}}
        self.stored_forward = None
        self.unet_reference = None
        self.cache_success_count = 0
        self.cache_fail_count = 0
        self.fail_reasons = defaultdict(int)
        self.success_reasons = defaultdict(int)
        self.enumerated_timestep = {"value": -1}

    def log_skip(self, reason: str = 'disabled_by_default'):
        self.fail_reasons[reason] += 1
        self.cache_fail_count += 1

    def report(self):
        # report cache hit/miss counts, grouped by reason
        total = self.cache_success_count + self.cache_fail_count
        if total == 0:
            return
        logger = getLogger()
        level = logger.getEffectiveLevel()
        logger.log(level, "DeepCache Information:")
        for reason, count in self.fail_reasons.items():
            logger.log(level, f"  {reason}: {count}")
        for reason, count in self.success_reasons.items():
            logger.log(level, f"  {reason}: {count}")

    def deepcache_hook_model(self, unet, params: DeepCacheParams):
        """
        Hooks the given unet model to use DeepCache.
        """
        caching_level = params.cache_in_level
        # caching level 0 = no caching, idx for resnet layers
        cache_enable_step = params.cache_enable_step
        full_run_step_rate = params.full_run_step_rate  # '5' means run full model every 5 steps
        if full_run_step_rate < 1:
            print(f"DeepCache is enabled by user but disabled because full_run_step_rate ({full_run_step_rate}) < 1")
            return  # disabled
        if getattr(unet, '_deepcache_hooked', False):
            return  # already hooked
        CACHE_LAST = self.CACHE_LAST
        self.stored_forward = unet.forward
        self.enumerated_timestep["value"] = -1
        # clamp the requested caching level to the deepest valid block index
        valid_caching_in_level = min(caching_level, len(unet.input_blocks) - 1)
        valid_caching_out_level = min(valid_caching_in_level, len(unet.output_blocks) - 1)
        caching_level = valid_caching_out_level
        valid_cache_timestep_range = 50  # out of 1000 diffusion timesteps; caches lagging the current timestep by more than this are stale

        def put_cache(h: torch.Tensor, timestep: int, real_timestep: float):
            """
            Registers cache
            """
            CACHE_LAST["timestep"].add(timestep)
            assert h is not None, "Cannot cache None"
            # maybe move to cpu and load later for low vram?
            CACHE_LAST["last"] = h
            CACHE_LAST[f"timestep_{timestep}"] = h
            CACHE_LAST["real_timestep"] = real_timestep

        def get_cache(current_timestep: int, real_timestep: float) -> Optional[torch.Tensor]:
            """
            Returns the cached tensor for the given timestep and cache key.
            """
            if current_timestep < cache_enable_step:
                self.fail_reasons['disabled'] += 1
                self.cache_fail_count += 1
                return None
            elif full_run_step_rate < 1:
                self.fail_reasons['full_run_step_rate_disabled'] += 1
                self.cache_fail_count += 1
                return None
            elif current_timestep % full_run_step_rate == 0:
                if f"timestep_{current_timestep}" in CACHE_LAST:
                    self.cache_success_count += 1
                    self.success_reasons['cached_exact'] += 1
                    CACHE_LAST["last"] = CACHE_LAST[f"timestep_{current_timestep}"]  # update last
                    return CACHE_LAST[f"timestep_{current_timestep}"]
                self.fail_reasons['full_run_step_rate_division'] += 1
                self.cache_fail_count += 1
                return None
            elif CACHE_LAST.get("real_timestep", 0) + valid_cache_timestep_range < real_timestep:
                self.fail_reasons['cache_outdated'] += 1
                self.cache_fail_count += 1
                return None
            # check if cache exists
            if "last" in CACHE_LAST:
                self.success_reasons['cached_last'] += 1
                self.cache_success_count += 1
                return CACHE_LAST["last"]
            self.fail_reasons['not_cached'] += 1
            self.cache_fail_count += 1
            return None
        def hijacked_unet_forward(x, timesteps=None, context=None, y=None, **kwargs):
            cache_cond = lambda: self.enumerated_timestep["value"] % full_run_step_rate == 0 or self.enumerated_timestep["value"] > cache_enable_step
            use_cache_cond = lambda: self.enumerated_timestep["value"] > cache_enable_step and self.enumerated_timestep["value"] % full_run_step_rate != 0
            nonlocal CACHE_LAST
            assert (y is not None) == (
                hasattr(unet, 'num_classes') and unet.num_classes is not None  # v2 or xl
            ), "must specify y if and only if the model is class-conditional"
            hs = []
            t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False).to(unet.dtype)
            emb = unet.time_embed(t_emb)
            if hasattr(unet, 'num_classes') and unet.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + unet.label_emb(y)
            real_timestep = timesteps[0].item()
            h = x.type(unet.dtype)
            cached_h = get_cache(self.enumerated_timestep["value"], real_timestep)
            for idx, module in enumerate(unet.input_blocks):
                self.log_skip('run_before_cache_input_block')
                h = forward_timestep_embed(module, h, emb, context)
                hs.append(h)
                if cached_h is not None and use_cache_cond() and idx == caching_level:
                    break  # the cached tensor stands in for everything below this level
            if not use_cache_cond():
                self.log_skip('run_before_cache_middle_block')
                h = forward_timestep_embed(unet.middle_block, h, emb, context)
            relative_cache_level = len(unet.output_blocks) - caching_level - 1
            for idx, module in enumerate(unet.output_blocks):
                if cached_h is not None and use_cache_cond() and idx == relative_cache_level:
                    # use cache
                    h = cached_h
                elif cache_cond() and idx == relative_cache_level:
                    # put cache
                    put_cache(h, self.enumerated_timestep["value"], real_timestep)
                elif cached_h is not None and use_cache_cond() and idx < relative_cache_level:
                    # skip, h is already cached
                    continue
                hsp = hs.pop()
                h = torch.cat([h, hsp], dim=1)
                del hsp
                if len(hs) > 0:
                    output_shape = hs[-1].shape
                else:
                    output_shape = None
                h = forward_timestep_embed(module, h, emb, context, output_shape=output_shape)
            h = h.type(x.dtype)
            self.enumerated_timestep["value"] += 1
            if unet.predict_codebook_ids:
                return unet.id_predictor(h)
            else:
                return unet.out(h)

        unet.forward = hijacked_unet_forward
        unet._deepcache_hooked = True
        self.unet_reference = unet

    def detach(self):
        if self.unet_reference is None:
            return
        if not getattr(self.unet_reference, '_deepcache_hooked', False):
            return
        # detach
        self.unet_reference.forward = self.stored_forward
        self.unet_reference._deepcache_hooked = False
        self.unet_reference = None
        self.stored_forward = None
        self.CACHE_LAST.clear()
        self.cache_fail_count = self.cache_success_count = 0
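For orientation, here is a minimal usage sketch of the session API above (hypothetical standalone use; `unet` stands in for a loaded ldm/sgm UNetModel — in the webui the hooked model is `shared.sd_model.model.diffusion_model`, as the script below shows):

# Minimal sketch, assuming `unet` is a UNetModel and a sampler runs elsewhere:
session = DeepCacheSession()
params = DeepCacheParams(cache_in_level=0, cache_enable_step=8, full_run_step_rate=5)
session.deepcache_hook_model(unet, params)   # swaps in the caching forward
# ... run sampling; steps past cache_enable_step reuse cached features ...
session.report()   # logs cache hit/miss counts grouped by reason
session.detach()   # restores the original unet.forward and clears the cache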
78 changes: 78 additions & 0 deletions extensions-builtin/deepcache/scripts/deepcache_script.py
@@ -0,0 +1,78 @@
from modules import scripts, script_callbacks, shared, processing
from deepcache import DeepCacheSession, DeepCacheParams
from scripts.deepcache_xyz import add_axis_options


class ScriptDeepCache(scripts.Script):

    name = "DeepCache"
    session: DeepCacheSession = None

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def get_deepcache_params(self, steps: int, enable_step_at: int = None) -> DeepCacheParams:
        return DeepCacheParams(
            cache_in_level=shared.opts.deepcache_cache_resnet_level,
            cache_enable_step=int(shared.opts.deepcache_cache_enable_step_percentage * steps) if enable_step_at is None else enable_step_at,
            full_run_step_rate=shared.opts.deepcache_full_run_step_rate,
        )

    def process_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
        print("DeepCache process")
        self.detach_deepcache()
        if shared.opts.deepcache_enable:
            self.configure_deepcache(self.get_deepcache_params(p.steps))

    def before_hr(self, p: processing.StableDiffusionProcessing, *args):
        print("DeepCache before_hr")
        if self.session is not None:
            self.session.enumerated_timestep["value"] = -1  # reset enumerated timestep
        if not shared.opts.deepcache_hr_reuse:
            self.detach_deepcache()
            if shared.opts.deepcache_enable:
                hr_steps = getattr(p, 'hr_second_pass_steps', 0) or p.steps  # use second pass steps if available
                enable_step = int(shared.opts.deepcache_cache_enable_step_percentage_hr * hr_steps)
                self.configure_deepcache(self.get_deepcache_params(hr_steps, enable_step_at=enable_step))

    def postprocess_batch(self, p: processing.StableDiffusionProcessing, *args, **kwargs):
        print("DeepCache postprocess")
        self.detach_deepcache()

    def configure_deepcache(self, params: DeepCacheParams):
        if self.session is None:
            self.session = DeepCacheSession()
        self.session.deepcache_hook_model(
            shared.sd_model.model.diffusion_model,  # unet_model
            params
        )

    def detach_deepcache(self):
        print("Detaching DeepCache")
        if self.session is None:
            return
        self.session.report()
        self.session.detach()
        self.session = None

def on_ui_settings():
    import gradio as gr
    options = {
        "deepcache_explanation": shared.OptionHTML("""
<a href='https://github.com/horseee/DeepCache'>DeepCache</a> optimizes generation by caching mid-block outputs, which are known to hold high-level features, and reusing them in subsequent forward passes.
        """),
        "deepcache_enable": shared.OptionInfo(False, "Enable DeepCache").info("noticeable change in details of the generated picture"),
        "deepcache_cache_resnet_level": shared.OptionInfo(0, "Cache Resnet level", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("Deeper = fewer layers cached"),
        "deepcache_cache_enable_step_percentage": shared.OptionInfo(0.4, "DeepCache is enabled after this percentage of steps", gr.Slider, {"minimum": 0, "maximum": 1}).info("Percentage of initial steps with DeepCache disabled"),
        "deepcache_full_run_step_rate": shared.OptionInfo(5, "Refresh caches when the step is divisible by this number", gr.Slider, {"minimum": 0, "maximum": 1000, "step": 1}).info("5 = refresh caches every 5 steps"),
        "deepcache_hr_reuse": shared.OptionInfo(False, "Reuse for HR").info("Reuses cache information for HR generation"),
        "deepcache_cache_enable_step_percentage_hr": shared.OptionInfo(0.0, "DeepCache is enabled after this percentage of steps for HR", gr.Slider, {"minimum": 0, "maximum": 1}).info("Percentage of initial HR steps with DeepCache disabled"),
    }
    for name, opt in options.items():
        opt.section = ('deepcache', "DeepCache")
        shared.opts.add_option(name, opt)

script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_before_ui(add_axis_options)
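For intuition about the two gating lambdas in `hijacked_unet_forward`, here is a small illustration (not part of the PR) of which sampling steps run the full UNet under the defaults above (20 steps, `full_run_step_rate=5`, enable percentage 0.4):

# Hypothetical illustration of the step gating used by the hooked forward:
steps, full_run_step_rate = 20, 5
cache_enable_step = int(0.4 * steps)  # = 8, mirroring get_deepcache_params
for step in range(steps):
    use_cache = step > cache_enable_step and step % full_run_step_rate != 0
    print(step, "reuse cache" if use_cache else "full run")
# Steps 0-8 and every multiple of 5 run the full model; all other steps reuse the cache.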
64 changes: 64 additions & 0 deletions extensions-builtin/deepcache/scripts/deepcache_xyz.py
@@ -0,0 +1,64 @@
from modules import scripts
from modules.shared import opts

xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module


def int_applier(value_name: str, min_range: int = -1, max_range: int = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = int(value)
        # validate value against the allowed range; -1 means the bound is unset
        if min_range != -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if max_range != -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
    def apply_int(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = int(x)
    return apply_int


def bool_applier(value_name: str):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        assert value.lower() in ["true", "false"], f"Value {value} for {value_name} must be either true or false"
    def apply_bool(p, x, xs):
        validate(value_name, x)
        value_boolean = x.lower() == "true"
        opts.data[value_name] = value_boolean
    return apply_bool

def float_applier(value_name: str, min_range: float = -1, max_range: float = -1):
    """
    Returns a function that applies the given value to the given value_name in opts.data.
    """
    def validate(value_name: str, value: str):
        value = float(value)
        # validate value against the allowed range; -1 means the bound is unset
        if min_range != -1:
            assert value >= min_range, f"Value {value} for {value_name} must be greater than or equal to {min_range}"
        if max_range != -1:
            assert value <= max_range, f"Value {value} for {value_name} must be less than or equal to {max_range}"
    def apply_float(p, x, xs):
        validate(value_name, x)
        opts.data[value_name] = float(x)
    return apply_float

def add_axis_options():
    extra_axis_options = [
        xyz_grid.AxisOption("[DeepCache] Enabled", str, bool_applier("deepcache_enable"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[DeepCache] Cache Resnet level", int, int_applier("deepcache_cache_resnet_level", 0, 10)),
        xyz_grid.AxisOption("[DeepCache] Cache Disable initial step percentage", float, float_applier("deepcache_cache_enable_step_percentage", 0, 1)),
        xyz_grid.AxisOption("[DeepCache] Cache Refresh Rate", int, int_applier("deepcache_full_run_step_rate", 0, 1000)),
        xyz_grid.AxisOption("[DeepCache] HR Reuse", str, bool_applier("deepcache_hr_reuse"), choices=xyz_grid.boolean_choice(reverse=True)),
        xyz_grid.AxisOption("[DeepCache] HR Cache Disable initial step percentage", float, float_applier("deepcache_cache_enable_step_percentage_hr", 0, 1)),
    ]
    set_a = {opt.label for opt in xyz_grid.axis_options}
    set_b = {opt.label for opt in extra_axis_options}
    if set_a.intersection(set_b):
        return  # already registered

    xyz_grid.axis_options.extend(extra_axis_options)
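A hypothetical sketch of how the grid script consumes one of these appliers (`p`, `x`, `xs` follow xyz_grid's convention: the processing object, the current cell's value as a string, and all values on the axis):

# Hypothetical usage sketch; xyz_grid calls the returned closure once per grid cell:
apply = float_applier("deepcache_cache_enable_step_percentage", 0, 1)
apply(p=None, x="0.35", xs=["0.2", "0.35", "0.5"])  # validates the string, then stores 0.35 in opts.data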
62 changes: 62 additions & 0 deletions extensions-builtin/deepcache/scripts/forward_timestep_embed_patch.py
@@ -0,0 +1,62 @@
"""
Patched forward_timestep_embed function to support the following:
@source https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/modules/diffusionmodules/openaimodel.py
"""
from ldm.modules.attention import SpatialTransformer
try:
from ldm.modules.attention import SpatialVideoTransformer
except (ImportError, ModuleNotFoundError):
SpatialVideoTransformer = None
from ldm.modules.diffusionmodules.openaimodel import TimestepBlock, TimestepEmbedSequential, Upsample
try:
from ldm.modules.diffusionmodules.openaimodel import VideoResBlock
except (ImportError, ModuleNotFoundError):
VideoResBlock = None

# SD XL modules from generative-models repo
from sgm.modules.attention import SpatialTransformer as SpatialTransformerSGM
try:
from sgm.modules.attention import SpatialVideoTransformer as SpatialVideoTransformerSGM
except (ImportError, ModuleNotFoundError):
SpatialVideoTransformerSGM = None
from sgm.modules.diffusionmodules.openaimodel import TimestepBlock as TimestepBlockSGM, Upsample as UpsampleSGM
try:
from sgm.modules.diffusionmodules.openaimodel import VideoResBlock as VideoResBlockSGM
except (ImportError, ModuleNotFoundError):
VideoResBlockSGM = None

import torch.nn.functional as F

def forward_timestep_embed(ts: TimestepEmbedSequential, x, emb, context=None, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
    # build isinstance tuples that skip classes absent from older ldm/sgm versions,
    # so a None from the optional imports above never reaches isinstance()
    video_res_blocks = tuple(cls for cls in (VideoResBlock, VideoResBlockSGM) if cls is not None)
    video_transformers = tuple(cls for cls in (SpatialVideoTransformer, SpatialVideoTransformerSGM) if cls is not None)
    for layer in ts:
        if video_res_blocks and isinstance(layer, video_res_blocks):
            x = layer(x, emb, num_video_frames, image_only_indicator)
        elif isinstance(layer, (TimestepBlock, TimestepBlockSGM)):
            x = layer(x, emb)
        elif video_transformers and isinstance(layer, video_transformers):
            x = layer(x, context, time_context, num_video_frames, image_only_indicator)
        elif isinstance(layer, (SpatialTransformer, SpatialTransformerSGM)):
            x = layer(x, context)
        elif isinstance(layer, (Upsample, UpsampleSGM)):
            x = forward_upsample(layer, x, output_shape=output_shape)
        else:
            x = layer(x)
    return x

def forward_upsample(self: Upsample, x, output_shape=None):
    assert x.shape[1] == self.channels
    if self.dims == 3:
        shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
        if output_shape is not None:
            shape[1] = output_shape[3]
            shape[2] = output_shape[4]
    else:
        shape = [x.shape[2] * 2, x.shape[3] * 2]
        if output_shape is not None:
            shape[0] = output_shape[2]
            shape[1] = output_shape[3]

    x = F.interpolate(x, size=shape, mode="nearest")
    if self.use_conv:
        x = self.conv(x)
    return x
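The `output_shape` plumbing exists because, at odd spatial sizes, a plain 2x nearest upsample need not match the skip tensor it is later concatenated with. A hypothetical illustration, separate from the code above:

# Hypothetical illustration: a 5x5 feature map downsamples to 3x3; doubling 3 gives 6,
# so the upsample must instead target the saved skip tensor's shape.
import torch
import torch.nn.functional as F

skip = torch.randn(1, 4, 5, 5)                              # tensor stored on the way down
h = torch.randn(1, 4, 3, 3)                                 # tensor coming back up
h = F.interpolate(h, size=skip.shape[2:], mode="nearest")   # 3x3 -> 5x5, matching the skip
assert torch.cat([h, skip], dim=1).shape == (1, 8, 5, 5)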