Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#13276 from woweenie/patch-1
Browse files Browse the repository at this point in the history
patch DDPM.register_betas so that users can put given_betas in model yaml
  • Loading branch information
AUTOMATIC1111 authored Sep 30, 2023
2 parents 7ce1f3a + d9d9414 commit e309583
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf, ListConfig
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
Expand All @@ -17,6 +17,7 @@
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer
import tomesd
import numpy as np

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
Expand Down Expand Up @@ -132,6 +133,7 @@ def setup_model():
os.makedirs(model_path, exist_ok=True)

enable_midas_autodownload()
patch_given_betas()


def checkpoint_tiles(use_short=False):
Expand Down Expand Up @@ -455,6 +457,17 @@ def load_model_wrapper(model_type):
midas.api.load_model = load_model_wrapper


def patch_given_betas():
original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
def patched_register_schedule(*args, **kwargs):
if args[1] is not None and isinstance(args[1], ListConfig):
modified_args = list(args) # Convert args tuple to a list
modified_args[1] = np.array(args[1]) # Modify the desired element
args = tuple(modified_args) # Convert the list back to a tuple
original_register_schedule(*args, **kwargs)
ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule


def repair_config(sd_config):

if not hasattr(sd_config.model.params, "use_ema"):
Expand Down

0 comments on commit e309583

Please sign in to comment.