
Commit

Merge pull request AUTOMATIC1111#16567 from AUTOMATIC1111/feat/sdxl-vpred

Support and automatically detect SDXL V-prediction models
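Background for the change (not part of the diff): a v-prediction model is trained to output the velocity

    v_t = alpha_t * eps - sigma_t * x_0

rather than the noise eps (Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models", 2022), so sampling from such a checkpoint needs the VScaling denoiser wired up in the new YAML below instead of the EpsScaling that stock SDXL uses.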
AUTOMATIC1111 authored Oct 19, 2024
2 parents 907bfb5 + 1ae073c commit 8b19b75
Showing 3 changed files with 106 additions and 3 deletions.
98 changes: 98 additions & 0 deletions configs/sd_xl_v.yaml
@@ -0,0 +1,98 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        weighting_config:
          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers
        legacy: False

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # crossattn cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11
          # crossattn and vector cond
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False
          # vector cond
          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two
          # vector cond
          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256 # multiplied by two

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
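The file above is essentially the stock sd_xl_base.yaml from Stability's generative-models repo with one functional change: scaling_config points at denoiser_scaling.VScaling instead of denoiser_scaling.EpsScaling. Each target:/params: pair is resolved at load time by the get_obj_from_str helper visible in the next file's context lines. A minimal sketch of that resolution pattern, assuming OmegaConf; the instantiate_from_config wrapper here is illustrative, not the exact webui call site:

import importlib

from omegaconf import OmegaConf


def get_obj_from_str(string):
    # "sgm.modules...denoiser_scaling.VScaling" -> import the module, return the class
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)


def instantiate_from_config(config):
    # Build the object named by `target:`, passing `params:` as keyword arguments
    return get_obj_from_str(config["target"])(**config.get("params", {}))


conf = OmegaConf.load("configs/sd_xl_v.yaml")
engine = instantiate_from_config(conf.model)  # sgm.models.diffusion.DiffusionEngine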
7 changes: 4 additions & 3 deletions modules/sd_models.py
@@ -783,7 +783,7 @@ def get_obj_from_str(string, reload=False):
     return getattr(importlib.import_module(module, package=None), cls)


-def load_model(checkpoint_info=None, already_loaded_state_dict=None):
+def load_model(checkpoint_info=None, already_loaded_state_dict=None, checkpoint_config=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()

@@ -801,7 +801,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     else:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+    if not checkpoint_config:
+        checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
     clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

     timer.record("find config")
@@ -974,7 +975,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
         if sd_model is not None:
             send_model_to_trash(sd_model)

-        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
+        load_model(checkpoint_info, already_loaded_state_dict=state_dict, checkpoint_config=checkpoint_config)
         return model_data.sd_model

     try:
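Net effect of these three hunks: load_model gains an optional checkpoint_config argument and only probes the state dict when the caller did not supply one, and reload_model_weights forwards the config it has already detected. One plausible reason this matters (an inference, not stated in the commit): detection in the next file deletes the v_pred marker from the state dict, so a second probe of the same dict would mis-detect. A toy, runnable reduction of the new flow; the stubs stand in for the real webui helpers:

# Toy reduction of the patched control flow; not the real webui functions.
def find_checkpoint_config(state_dict):
    # mimics guess_model_config_from_state_dict in the next file
    if 'v_pred' in state_dict:
        del state_dict['v_pred']          # marker is consumed here
        return "sd_xl_v.yaml"
    return "sd_xl_base.yaml"

def load_model(state_dict, checkpoint_config=None):
    if not checkpoint_config:             # only probe when the caller gave nothing
        checkpoint_config = find_checkpoint_config(state_dict)
    return checkpoint_config

sd = {"v_pred": object()}                 # pretend v-pred checkpoint
cfg = find_checkpoint_config(sd)          # first probe pops the marker
assert load_model(sd, checkpoint_config=cfg) == "sd_xl_v.yaml"
assert load_model(sd) == "sd_xl_base.yaml"  # re-probing would mis-detect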
4 changes: 4 additions & 0 deletions modules/sd_models_config.py
@@ -14,6 +14,7 @@
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxlv = os.path.join(sd_configs_path, "sd_xl_v.yaml")
 config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
@@ -81,6 +82,9 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
+            if ('v_pred' in sd):
+                del sd['v_pred']
+                return config_sdxlv
             return config_sdxl

     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
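Detection keys off a v_pred entry that v-prediction SDXL checkpoints carry in their state dict; the entry is deleted so it is not treated as a model weight, then the v-pred config is returned. A standalone sketch of the same probe outside the webui, assuming safetensors; the file name and config strings are illustrative:

import safetensors.torch

sd = safetensors.torch.load_file("my_sdxl_checkpoint.safetensors")  # hypothetical file

# Same check as guess_model_config_from_state_dict above
if 'v_pred' in sd:
    del sd['v_pred']                    # marker entry, not a real weight
    config = "sd_xl_v.yaml"             # v-prediction -> VScaling denoiser
else:
    config = "sd_xl_base.yaml"          # noise prediction -> EpsScaling denoiser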
