Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Integrate optional speed and memory improvements by token merging (via dbolya/tomesd) #9256

Merged
merged 21 commits into from
May 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion modules/generation_parameters_copypaste.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,10 @@ def parse_generation_parameters(x: str):
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
('Token merging ratio', 'token_merging_ratio'),
('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'),
('NGMS', 's_min_uncond'),
('NGMS', 's_min_uncond')
]


Expand Down
33 changes: 33 additions & 0 deletions modules/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@

from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
import tomesd

# add a logger for the processing module
logger = logging.getLogger(__name__)
# manually set output level here since there is no option to do so yet through launch options
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')


# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
Expand Down Expand Up @@ -471,6 +478,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
index = position_in_batch + iteration * p.batch_size

clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
enable_hr = getattr(p, 'enable_hr', False)

generation_params = {
"Steps": p.steps,
Expand All @@ -489,6 +497,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
Expand Down Expand Up @@ -522,9 +532,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()

if opts.token_merging_ratio > 0:
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")

res = process_images_inner(p)

finally:
# undo model optimizations made by tomesd
if opts.token_merging_ratio > 0:
tomesd.remove_patch(p.sd_model)
logger.debug('Token merging model optimizations removed')

# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
Expand Down Expand Up @@ -977,8 +996,22 @@ def save_intermediate(image, index):
x = None
devices.torch_gc()

# apply token merging optimizations from tomesd for high-res pass
if opts.token_merging_ratio_hr > 0:
# in case the user has used separate merge ratios
if opts.token_merging_ratio > 0:
tomesd.remove_patch(self.sd_model)
logger.debug('Adjusting token merging ratio for high-res pass')

sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")

samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
tomesd.remove_patch(self.sd_model)
logger.debug('Removed token merging optimizations from model')

self.is_hr_pass = False

return samples
Expand Down
23 changes: 23 additions & 0 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
Expand Down Expand Up @@ -578,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")

return sd_model


def apply_token_merging(sd_model, hr: bool):
    """
    Patch *sd_model* in place with tomesd's token-merging speed and memory
    optimizations.

    Args:
        sd_model: the Stable Diffusion model to patch
        hr (bool): True if called in the context of a high-res pass; selects
            the high-res merge ratio from settings instead of the base one
    """

    # Pick the merge ratio configured for the current pass.
    merge_ratio = shared.opts.token_merging_ratio_hr if hr else shared.opts.token_merging_ratio

    # Only self-attention is merged; random merging stays off because it
    # can cause issues with some samplers.
    tomesd.apply_patch(
        sd_model,
        ratio=merge_ratio,
        use_rand=False,
        merge_attn=True,
        merge_crossattn=False,
        merge_mlp=False,
    )
3 changes: 3 additions & 0 deletions modules/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ def list_samplers():
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
"token_merging_ratio_hr": OptionInfo(0, "Merging Ratio (high-res pass)", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}),
"token_merging_ratio": OptionInfo(0, "Merging Ratio", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1})
}))

options_templates.update(options_section(('compatibility', "Compatibility"), {
Expand Down Expand Up @@ -458,6 +460,7 @@ def list_samplers():
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
}))


options_templates.update()


Expand Down
1 change: 1 addition & 0 deletions requirements_versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ torchsde==0.2.5
safetensors==0.3.1
httpcore<=0.15
fastapi==0.94.0
tomesd>=0.1.2