Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#178 from brkirch/fix-half-precision
Browse files Browse the repository at this point in the history
Fix support for half precision and `--upcast-sampling` on non-MPS devices, keep older web UI version support
  • Loading branch information
Mikubill authored Feb 18, 2023
2 parents de32133 + 03a8fc0 commit 5b53034
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 19 deletions.
1 change: 1 addition & 0 deletions preload.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
def preload(parser):
    """Register ControlNet-related command-line options on *parser*.

    Called by the web UI before full startup so these flags are known
    at argument-parsing time.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The web UI's argument parser to extend.
    """
    parser.add_argument(
        "--controlnet-dir",
        type=str,
        default=None,
        help="Path to directory with ControlNet models",
    )
    parser.add_argument(
        "--no-half-controlnet",
        action='store_true',
        default=None,
        help="do not switch the ControlNet models to 16-bit floats (only needed without --no-half)",
    )
25 changes: 6 additions & 19 deletions scripts/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from omegaconf import OmegaConf
from modules import devices, lowvram, shared, scripts

cond_cast_unet = getattr(devices, 'cond_cast_unet', lambda x: x)

from ldm.util import exists
from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import conv_nd, linear, zero_module, timestep_embedding
Expand Down Expand Up @@ -113,9 +115,6 @@ def __init__(self, state_dict, config_path, weight=1.0, lowvram=False, base_mode
self.control_model.to(devices.get_device_for("controlnet"))

def hook(self, model, parent_model):
if devices.get_device_for("controlnet").type == 'mps':
from modules.devices import cond_cast_unet

outer = self

def guidance_schedule_handler(x):
Expand All @@ -137,10 +136,7 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
assert timesteps is not None, ValueError(f"insufficient timestep: {timesteps}")
hs = []
with th.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
if devices.get_device_for("controlnet").type == 'mps':
t_emb = cond_cast_unet(t_emb)

t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
Expand Down Expand Up @@ -229,8 +225,7 @@ def __init__(
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
if devices.get_device_for("controlnet").type == 'mps':
use_fp16 = devices.dtype_unet == th.float16
use_fp16 = getattr(devices, 'dtype_unet', devices.dtype) == th.float16 and not shared.cmd_opts.no_half_controlnet

super().__init__()
if use_spatial_transformer:
Expand Down Expand Up @@ -447,18 +442,10 @@ def align(self, hint, h, w):
return hint

def forward(self, x, hint, timesteps, context, **kwargs):
if devices.get_device_for("controlnet").type == 'mps':
from modules.devices import cond_cast_unet

t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
if devices.get_device_for("controlnet").type == 'mps':
t_emb = cond_cast_unet(t_emb)

t_emb = cond_cast_unet(timestep_embedding(timesteps, self.model_channels, repeat_only=False))
emb = self.time_embed(t_emb)
if devices.get_device_for("controlnet").type == 'mps':
hint = cond_cast_unet(hint)

guided_hint = self.input_hint_block(hint, emb, context)
guided_hint = self.input_hint_block(cond_cast_unet(hint), emb, context)
outs = []

h1, w1 = x.shape[-2:]
Expand Down

0 comments on commit 5b53034

Please sign in to comment.