diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index b0e945a12dc..a0a98bbcb03 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -308,8 +308,10 @@ def parse_generation_parameters(x: str):
     ('UniPC skip type', 'uni_pc_skip_type'),
     ('UniPC order', 'uni_pc_order'),
     ('UniPC lower order final', 'uni_pc_lower_order_final'),
+    ('Token merging ratio', 'token_merging_ratio'),
+    ('Token merging ratio hr', 'token_merging_ratio_hr'),
     ('RNG', 'randn_source'),
-    ('NGMS', 's_min_uncond'),
+    ('NGMS', 's_min_uncond')
 ]
diff --git a/modules/processing.py b/modules/processing.py
index f902b9df969..94fe2625cd7 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -29,6 +29,13 @@
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
+import tomesd
+
+# add a logger for the processing module
+logger = logging.getLogger(__name__)
+# manually set output level here since there is no option to do so yet through launch options
+# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s')
+
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 
 opt_C = 4
@@ -471,6 +478,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
     index = position_in_batch + iteration * p.batch_size
 
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+    enable_hr = getattr(p, 'enable_hr', False)
 
     generation_params = {
         "Steps": p.steps,
@@ -489,6 +497,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+        "Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio,
+        "Token merging ratio hr": None if not enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
         "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
@@ -522,9 +532,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if k == 'sd_vae':
                 sd_vae.reload_vae_weights()
 
+        if opts.token_merging_ratio > 0:
+            sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
+            logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'")
+
         res = process_images_inner(p)
 
     finally:
+        # undo model optimizations made by tomesd
+        if opts.token_merging_ratio > 0:
+            tomesd.remove_patch(p.sd_model)
+            logger.debug('Token merging model optimizations removed')
+
         # restore opts to original state
         if p.override_settings_restore_afterwards:
             for k, v in stored_opts.items():
@@ -977,8 +996,22 @@ def save_intermediate(image, index):
             x = None
             devices.torch_gc()
 
+            # apply token merging optimizations from tomesd for high-res pass
+            if opts.token_merging_ratio_hr > 0:
+                # in case the user has used separate merge ratios
+                if opts.token_merging_ratio > 0:
+                    tomesd.remove_patch(self.sd_model)
+                    logger.debug('Adjusting token merging ratio for high-res pass')
+
+                sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
+                logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'")
+
             samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
 
+            if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0:
+                tomesd.remove_patch(self.sd_model)
+                logger.debug('Removed token merging optimizations from model')
+
             self.is_hr_pass = False
 
             return samples
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 3316d021184..4c9a0a1fb70 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -17,6 +17,7 @@
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
+import tomesd
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -578,3 +579,25 @@ def unload_model_weights(sd_model=None, info=None):
     print(f"Unloaded weights {timer.summary()}.")
 
     return sd_model
+
+
+def apply_token_merging(sd_model, hr: bool):
+    """
+    Applies speed and memory optimizations from tomesd.
+
+    Args:
+        hr (bool): True if called in the context of a high-res pass
+    """
+
+    ratio = shared.opts.token_merging_ratio
+    if hr:
+        ratio = shared.opts.token_merging_ratio_hr
+
+    tomesd.apply_patch(
+        sd_model,
+        ratio=ratio,
+        use_rand=False,  # can cause issues with some samplers
+        merge_attn=True,
+        merge_crossattn=False,
+        merge_mlp=False
+    )
diff --git a/modules/shared.py b/modules/shared.py
index 96a20a6bd9d..a5e8d0bd019 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -350,6 +350,8 @@ def list_samplers():
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
+    "token_merging_ratio_hr": OptionInfo(0, "Merging Ratio (high-res pass)", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}),
+    "token_merging_ratio": OptionInfo(0, "Merging Ratio", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1})
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
@@ -458,6 +460,7 @@ def list_samplers():
     "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
 }))
 
+options_templates.update()
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 0a276b0b4b8..0e03deedc0c 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -26,3 +26,4 @@ torchsde==0.2.5
 safetensors==0.3.1
 httpcore<=0.15
 fastapi==0.94.0
+tomesd>=0.1.2
\ No newline at end of file
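
For reference (not part of the patch): a minimal sketch of the flow the changes implement, namely patching the model with tomesd before sampling and removing the patch afterwards. `model`, `opts`, and `run_sampling` are placeholders standing in for the loaded Stable Diffusion model, modules.shared.opts, and the actual sampling call; only the tomesd calls are real library API.

import tomesd

def generate_with_token_merging(model, opts, run_sampling):
    """Apply tomesd around a sampling call, mirroring process_images() above."""
    patched = opts.token_merging_ratio > 0
    if patched:
        # same call apply_token_merging() wraps: merge self-attention tokens only,
        # deterministically, at the user-configured ratio
        tomesd.apply_patch(
            model,
            ratio=opts.token_merging_ratio,
            use_rand=False,
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
    try:
        return run_sampling(model)
    finally:
        if patched:
            tomesd.remove_patch(model)  # leave the model unmodified for later runs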