Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve hires fix efficiency; fix: better SDXL hires fix default implementation; fix: cascade bug #287

Merged
merged 12 commits into from
Jul 21, 2024
Merged
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ repos:
types-tabulate,
types-tqdm,
types-urllib3,
horde_sdk==0.12.0,
horde_sdk==0.14.0,
horde_model_reference==0.8.1,
]
7 changes: 2 additions & 5 deletions hordelib/comfy_horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,8 @@ def __init__(
# Load our pipelines
self._load_pipelines()

stdio = OutputCollector()
with contextlib.redirect_stdout(stdio):
# Load our custom nodes
self._load_custom_nodes()
stdio.replay()
# Load our custom nodes
self._load_custom_nodes()

self._comfyui_callback = comfyui_callback

Expand Down
89 changes: 69 additions & 20 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from enum import Enum, auto
from types import FunctionType

from horde_model_reference.meta_consts import STABLE_DIFFUSION_BASELINE_CATEGORY, get_baseline_native_resolution
from horde_sdk.ai_horde_api.apimodels import ImageGenerateJobPopResponse
from horde_sdk.ai_horde_api.apimodels.base import (
GenMetadataEntry,
Expand Down Expand Up @@ -78,15 +79,55 @@ def __init__(
self.faults = faults


def _calc_upscale_sampler_steps(payload):
"""Calculates the amount of hires_fix upscaler steps based on the denoising used and the steps used for the
primary image"""
upscale_steps = round(payload["ddim_steps"] * (0.9 - payload["hires_fix_denoising_strength"]))
if upscale_steps < 3:
upscale_steps = 3
def _calc_upscale_sampler_steps(
payload: dict,
) -> int:
"""Use `ImageUtils.calc_upscale_sampler_steps(...)` to calculate the number of steps for the upscale sampler.

logger.debug(f"Upscale steps calculated as {upscale_steps}")
return upscale_steps
Args:
payload (dict): The payload to use for the calculation.

Returns:
int: The number of steps to use.
"""
model_name = payload.get("model_name")
baseline = None
native_resolution = None
if model_name is not None:
baseline = SharedModelManager.model_reference_manager.stable_diffusion.get_model_baseline(model_name)
if baseline is not None:
try:
baseline = STABLE_DIFFUSION_BASELINE_CATEGORY(baseline)
except ValueError:
baseline = None
logger.warning(
f"Model {model_name} has an invalid baseline {baseline} so we cannot calculate "
"hires fix upscale steps.",
)
if baseline is not None:
native_resolution = get_baseline_native_resolution(baseline)

width: int | None = payload.get("width")
height: int | None = payload.get("height")
hires_fix_denoising_strength: float | None = payload.get("hires_fix_denoising_strength")
ddim_steps: int | None = payload.get("ddim_steps")

if width is None or height is None:
raise ValueError("Width and height must be set to calculate upscale sampler steps")

if hires_fix_denoising_strength is None:
raise ValueError("Hires fix denoising strength must be set to calculate upscale sampler steps")

if ddim_steps is None:
raise ValueError("DDIM steps must be set to calculate upscale sampler steps")

return ImageUtils.calc_upscale_sampler_steps(
model_native_resolution=native_resolution,
width=width,
height=height,
hires_fix_denoising_strength=hires_fix_denoising_strength,
ddim_steps=ddim_steps,
)


class HordeLib:
Expand Down Expand Up @@ -825,13 +866,15 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
raise RuntimeError(f"Invalid key {key}")
elif "*" in key:
key, multiplier = key.split("*", 1)
elif key in payload:

if key in payload:
if multiplier:
pipeline_params[newkey] = round(payload.get(key) * float(multiplier))
else:
pipeline_params[newkey] = payload.get(key)
else:
elif not isinstance(key, FunctionType):
logger.error(f"Parameter {key} not found")

# We inject these parameters to ensure the HordeCheckpointLoader knows what file to load, if necessary
# We don't want to hardcode this into the pipeline.json as we export this directly from ComfyUI
# and don't want to have to rememebr to re-add those keys
Expand Down Expand Up @@ -874,16 +917,22 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
baseline = None
if model_details:
baseline = model_details.get("baseline")
if baseline and (baseline == "stable_cascade" or baseline == "stable_diffusion_xl"):
new_width, new_height = ImageUtils.get_first_pass_image_resolution_max(
original_width,
original_height,
)
else:
new_width, new_height = ImageUtils.get_first_pass_image_resolution_min(
original_width,
original_height,
)
if baseline:
if baseline == "stable_cascade":
new_width, new_height = ImageUtils.get_first_pass_image_resolution_max(
original_width,
original_height,
)
elif baseline == "stable_diffusion_xl":
new_width, new_height = ImageUtils.get_first_pass_image_resolution_sdxl(
original_width,
original_height,
)
else: # fall through case; only `stable diffusion 1`` at time of writing
new_width, new_height = ImageUtils.get_first_pass_image_resolution_min(
original_width,
original_height,
)

# This is the *target* resolution
pipeline_params["latent_upscale.width"] = original_width
Expand Down
3 changes: 2 additions & 1 deletion hordelib/initialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def initialise(
force_normal_vram_mode: bool = True,
extra_comfyui_args: list[str] | None = None,
disable_smart_memory: bool = False,
do_not_load_model_mangers: bool = False,
):
"""Initialise hordelib. This is required before using any other hordelib functions.

Expand Down Expand Up @@ -96,7 +97,7 @@ def initialise(
# Initialise model manager
from hordelib.shared_model_manager import SharedModelManager

SharedModelManager()
SharedModelManager(do_not_load_model_mangers=do_not_load_model_mangers)

sys.argv = sys_arg_bkp

Expand Down
17 changes: 1 addition & 16 deletions hordelib/model_manager/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,7 @@ def load_model_database(self) -> None:
)

def download_model_reference(self) -> dict:
try:
logger.debug(f"Downloading Model Reference for {self.models_db_name}")
response = requests.get(self.remote_db)
logger.debug("Downloaded Model Reference successfully")
models = response.json()
logger.info("Updated Model Reference from remote.")
return models
except Exception as e: # XXX Double check and/or rework this
logger.error(
f"Download failed: {e}",
)
logger.warning("Model Reference not downloaded, using local copy")
if self.models_db_path.exists():
return json.loads(self.models_db_path.read_text())
logger.error("No local copy of Model Reference found!")
return {}
raise NotImplementedError("Downloading model databases is no longer supported within hordelib.")

def get_free_ram_mb(self) -> int:
"""Returns the amount of free RAM in MB rounded down to the nearest integer.
Expand Down
2 changes: 1 addition & 1 deletion hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class LoraModelManager(BaseModelManager):
)
LORA_API = "https://civitai.com/api/v1/models?types=LORA&sort=Highest%20Rated&primaryFileOnly=true"
MAX_RETRIES = 10 if not TESTS_ONGOING else 3
MAX_DOWNLOAD_THREADS = 5 if not TESTS_ONGOING else 15
MAX_DOWNLOAD_THREADS = 5 if not TESTS_ONGOING else 75
RETRY_DELAY = 3 if not TESTS_ONGOING else 0.2
"""The time to wait between retries in seconds"""
REQUEST_METADATA_TIMEOUT = 20 # Longer because civitai performs poorly on metadata requests for more than 5 models
Expand Down
1 change: 1 addition & 0 deletions hordelib/nodes/facerestore_cf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import model_management
import numpy as np
import torch

# from comfy_extras.chainner_models import model_loading
from hordelib.nodes.facerestore_cf.r_chainner import model_loading
from torchvision.transforms.functional import normalize
Expand Down
28 changes: 7 additions & 21 deletions hordelib/nodes/facerestore_cf/r_chainner/gfpganv1_clean_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,12 @@ def forward(
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
noise = [getattr(self.noises, f"noise{i}") for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
Expand All @@ -96,9 +92,7 @@ def forward(
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)

# main generation
Expand Down Expand Up @@ -160,14 +154,10 @@ def __init__(self, in_channels, out_channels, mode="down"):
def forward(self, x):
out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
# upsample/downsample
out = F.interpolate(
out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
out = F.interpolate(out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False)
out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
# skip
x = F.interpolate(
x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
)
x = F.interpolate(x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False)
skip = self.skip(x)
out = out + skip
return out
Expand Down Expand Up @@ -283,9 +273,7 @@ def __init__(
# load pre-trained stylegan2 model if necessary
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
torch.load(decoder_load_path, map_location=lambda storage, loc: storage)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
Expand Down Expand Up @@ -317,9 +305,7 @@ def __init__(
)
self.load_state_dict(state_dict)

def forward(
self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
):
def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs):
"""Forward function for GFPGANv1Clean.
Args:
x (Tensor): Input images.
Expand Down
6 changes: 1 addition & 5 deletions hordelib/nodes/facerestore_cf/r_chainner/model_loading.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean
from hordelib.nodes.facerestore_cf.r_chainner.types import PyTorchModel

Expand All @@ -21,9 +20,6 @@ def load_state_dict(state_dict) -> PyTorchModel:
state_dict_keys = list(state_dict.keys())

# GFPGAN
if (
"toRGB.0.weight" in state_dict_keys
and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
):
if "toRGB.0.weight" in state_dict_keys and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys:
model = GFPGANv1Clean(state_dict)
return model
28 changes: 7 additions & 21 deletions hordelib/nodes/facerestore_cf/r_chainner/stylegan2_clean_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ def forward(self, x, style):
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

weight = weight.view(
b * self.out_channels, c, self.kernel_size, self.kernel_size
)
weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

# upsample or downsample if necessary
if self.sample_mode == "upsample":
Expand Down Expand Up @@ -224,9 +222,7 @@ def forward(self, x, style, skip=None):
out = out + self.bias
if skip is not None:
if self.upsample:
skip = F.interpolate(
skip, scale_factor=2, mode="bilinear", align_corners=False
)
skip = F.interpolate(skip, scale_factor=2, mode="bilinear", align_corners=False)
out = out + skip
return out

Expand Down Expand Up @@ -257,9 +253,7 @@ class StyleGAN2GeneratorClean(nn.Module):
narrow (float): Narrow ratio for channels. Default: 1.0.
"""

def __init__(
self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
):
def __init__(self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1):
super(StyleGAN2GeneratorClean, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
Expand Down Expand Up @@ -362,9 +356,7 @@ def get_latent(self, x):
return self.style_mlp(x)

def mean_latent(self, num_latent):
latent_in = torch.randn(
num_latent, self.num_style_feat, device=self.constant_input.weight.device
)
latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent

Expand Down Expand Up @@ -398,16 +390,12 @@ def forward(
if randomize_noise:
noise = [None] * self.num_layers # for each style conv layer
else: # use the stored noise
noise = [
getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
]
noise = [getattr(self.noises, f"noise{i}") for i in range(self.num_layers)]
# style truncation
if truncation < 1:
style_truncation = []
for style in styles:
style_truncation.append(
truncation_latent + truncation * (style - truncation_latent)
)
style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
styles = style_truncation
# get style latents with injection
if len(styles) == 1:
Expand All @@ -422,9 +410,7 @@ def forward(
if inject_index is None:
inject_index = random.randint(1, self.num_latent - 1)
latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = (
styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
)
latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
latent = torch.cat([latent1, latent2], 1)

# main generation
Expand Down
4 changes: 2 additions & 2 deletions hordelib/nodes/facerestore_cf/r_chainner/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from typing import Union

from hordelib.nodes.facerestore_cf.r_chainner.gfpganv1_clean_arch import GFPGANv1Clean
Expand All @@ -11,7 +10,8 @@
def is_pytorch_face_model(model: object):
return isinstance(model, PyTorchFaceModels)

PyTorchModels = (*PyTorchFaceModels, )

PyTorchModels = (*PyTorchFaceModels,)
PyTorchModel = Union[PyTorchFaceModel]


Expand Down
Loading
Loading