
Commit

Merge pull request #15815 from AUTOMATIC1111/torch-float64-or-float32
fix soft inpainting on mps and xpu, torch_utils.float64
AUTOMATIC1111 authored Jun 8, 2024
2 parents 6450d24 + f015b94 commit b4723bb
Showing 3 changed files with 16 additions and 7 deletions.
9 changes: 4 additions & 5 deletions extensions-builtin/soft-inpainting/scripts/soft_inpainting.py
@@ -3,6 +3,7 @@
 import math
 from modules.ui_components import InputAccordion
 import modules.scripts as scripts
+from modules.torch_utils import float64
 
 
 class SoftInpaintingSettings:
@@ -79,13 +80,11 @@ def latent_blend(settings, a, b, t):
 
     # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
     # 64-bit operations are used here to allow large exponents.
-    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
+    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(float64(image_interp)).add_(0.00001)
 
     # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
-    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
-        settings.inpaint_detail_preservation) * one_minus_t3
-    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
-        settings.inpaint_detail_preservation) * t3
+    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(float64(a)).pow_(settings.inpaint_detail_preservation) * one_minus_t3
+    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(float64(b)).pow_(settings.inpaint_detail_preservation) * t3
     desired_magnitude = a_magnitude
     desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
     del a_magnitude, b_magnitude, t3, one_minus_t3
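
Why latent_blend wants 64-bit here: the magnitudes are raised to the inpaint_detail_preservation power, and with a large exponent float32 overflows to inf where float64 stays finite. A minimal sketch, not part of the commit (the tensor values and exponent are made up):

import torch

x = torch.full((1, 4), 100.0)   # stand-in latent vector; its L2 norm is 200
exponent = 32                   # stand-in for a large inpaint_detail_preservation

mag32 = torch.norm(x, p=2, dim=1).to(torch.float32).pow_(exponent)
mag64 = torch.norm(x, p=2, dim=1).to(torch.float64).pow_(exponent)

print(mag32)  # tensor([inf]); 200**32 is ~4.3e73, beyond float32's ~3.4e38 maximum
print(mag64)  # tensor([4.2950e+73], dtype=torch.float64)

On mps and xpu, which don't support float64, float64(image_interp) falls back to float32, trading this headroom for a dtype the device accepts; without that fallback, soft inpainting errored out on those backends.
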
5 changes: 3 additions & 2 deletions modules/sd_samplers_timesteps_impl.py
@@ -5,13 +5,14 @@
 
 from modules import shared
 from modules.models.diffusion.uni_pc import uni_pc
+from modules.torch_utils import float64
 
 
 @torch.no_grad()
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

@@ -43,7 +44,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
 
     extra_args = {} if extra_args is None else extra_args
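
Both samplers change only how the target dtype is chosen; the surrounding index arithmetic is untouched. For reference, a small sketch (the timestep values are made up) of what the pad call computes, prepending a zero so that alphas_prev[i] is the cumulative alpha of the step before timesteps[i]:

import torch

timesteps = torch.tensor([1, 251, 501, 751])
shifted = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))  # drop last, prepend 0
print(shifted)  # tensor([  0,   1, 251, 501])
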
9 changes: 9 additions & 0 deletions modules/torch_utils.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import torch.nn
+import torch
 
 
 def get_param(model) -> torch.nn.Parameter:
@@ -15,3 +16,11 @@ def get_param(model) -> torch.nn.Parameter:
         return param
 
     raise ValueError(f"No parameters found in model {model!r}")
+
+
+def float64(t: torch.Tensor):
+    """Return torch.float64, or torch.float32 if t is on an mps or xpu device (which lack float64 support)."""
+    match t.device.type:
+        case 'mps' | 'xpu':
+            return torch.float32
+    return torch.float64
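
A hypothetical usage sketch of the new helper (the device strings are illustrative; it simply replaces the inline mps/xpu conditional the samplers used before):

import torch
from modules.torch_utils import float64

x = torch.randn(4)   # CPU tensor
print(float64(x))    # torch.float64

# On a backend without float64 support the helper degrades to 32-bit, e.g.:
#   y = torch.randn(4, device='mps')
#   float64(y)  ->  torch.float32
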
