From 1ae073c052c2a1eb2b3e87464a576451b3197326 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sat, 19 Oct 2024 06:53:19 -0400 Subject: [PATCH] Support SDXL v-pred models --- configs/sd_xl_v.yaml | 98 +++++++++++++++++++++++++++++++++++++ modules/sd_models.py | 7 +-- modules/sd_models_config.py | 4 ++ 3 files changed, 106 insertions(+), 3 deletions(-) create mode 100644 configs/sd_xl_v.yaml diff --git a/configs/sd_xl_v.yaml b/configs/sd_xl_v.yaml new file mode 100644 index 00000000000..c755dc74fda --- /dev/null +++ b/configs/sd_xl_v.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/modules/sd_models.py b/modules/sd_models.py index 55bd9ca5e43..abe1c966c26 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False): return getattr(importlib.import_module(module, package=None), cls) -def load_model(checkpoint_info=None, already_loaded_state_dict=None): +def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None): from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -801,7 +801,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): else: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + if not checkpoint_config: + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) timer.record("find config") @@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): if sd_model is not None: send_model_to_trash(sd_model) - load_model(checkpoint_info, already_loaded_state_dict=state_dict) + load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config) return model_data.sd_model try: diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index fb44c5a8d98..3c1e4a1518f 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxlv = os.path.join(sd_configs_path, "sd_xl_v.yaml") config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") @@ -81,6 +82,9 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 9: return config_sdxl_inpainting else: + if ('v_pred' in sd): + del sd['v_pred'] + return config_sdxlv return config_sdxl if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: