Skip to content

Commit

Permalink
Merge pull request #220 from Haidra-Org/main
Browse files Browse the repository at this point in the history
feat: img2img on Stable Cascade, update comfyui to `40e124c6`; fix: minor bug fixes
  • Loading branch information
tazlin authored Mar 20, 2024
2 parents a81596b + 8db05ff commit 9ee3518
Show file tree
Hide file tree
Showing 24 changed files with 1,554 additions and 75 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.2.0
rev: 24.3.0
hooks:
- id: black
exclude: ^hordelib/nodes/.*\..*$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.3.3
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.8.0'
rev: 'v1.9.0'
hooks:
- id: mypy
exclude: ^examples/.*$ # FIXME
Expand Down
1 change: 1 addition & 0 deletions examples/kudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class KudosModel:
"inpainting",
"outpainting",
"txt2img",
"remix",
]

def __init__(self, model_filename=None):
Expand Down
12 changes: 7 additions & 5 deletions hordelib/comfy_horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,7 @@ def _fix_node_names(self, data: dict, design: dict) -> dict:
# pace. This is why the only thing that partially relies on that format, is in fact, optional.
def _patch_pipeline(self, data: dict, design: dict) -> dict:
"""Patch the pipeline data with the design data."""
# FIXME: This can now be done through the _meta.title key included with each API export.
# First replace comfyui standard types with hordelib node types
data = self._fix_pipeline_types(data)
# Now try to find better parameter names
Expand Down Expand Up @@ -689,11 +690,12 @@ def _run_pipeline(
# This is useful for dumping the entire pipeline to the terminal when
# developing and debugging new pipelines. A badly structured pipeline
# file just results in a cryptic error from comfy
if False: # This isn't here, Tazlin :)
pretty_pipeline = pformat(pipeline)
logger.warning(pretty_pipeline)
with open("pipeline_debug.json", "w") as outfile:
outfile.write(json.dumps(pipeline, indent=4))
# if False: # This isn't here, Tazlin :)
# with open("pipeline_debug.json", "w") as outfile:
# default = lambda o: f"<<non-serializable: {type(o).__qualname__}>>"
# outfile.write(json.dumps(pipeline, indent=4, default=default))
# pretty_pipeline = pformat(pipeline)
# logger.warning(pretty_pipeline)

# The client_id parameter here is just so we receive comfy callbacks for debugging.
# We pretend we are a web client and want async callbacks.
Expand Down
1 change: 1 addition & 0 deletions hordelib/config_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ def set_system_path() -> None:
"""Adds ComfyUI to the python path, as it is not a proper library."""
comfyui_path = get_comfyui_path()
sys.path.append(str(comfyui_path))
sys.path.append(str(comfyui_path) + "/comfy")
sys.path.append(str(get_hordelib_path() / "nodes"))
2 changes: 1 addition & 1 deletion hordelib/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from hordelib.config_path import get_hordelib_path

COMFYUI_VERSION = "2a813c3b09292c9aeab622ddf65d77e5d8171d0d"
COMFYUI_VERSION = "0c55f16c9e66eaa4915e288b34e4f848fb2d949f"
"""The exact version of ComfyUI version to load."""

REMOTE_PROXY = ""
Expand Down
72 changes: 71 additions & 1 deletion hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class HordeLib:
"hough": "control_mlsd_fp16.safetensors",
}

SOURCE_IMAGE_PROCESSING_OPTIONS = ["img2img", "inpainting", "outpainting"]
SOURCE_IMAGE_PROCESSING_OPTIONS = ["img2img", "inpainting", "outpainting", "remix"]

SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform", "sgm_uniform", "exponential"]

Expand Down Expand Up @@ -170,6 +170,12 @@ class HordeLib:
"model_name": {"datatype": str, "default": "stable_diffusion"}, # Used internally by hordelib
"stable_cascade_stage_b": {"datatype": str, "default": None}, # Stable Cascade
"stable_cascade_stage_c": {"datatype": str, "default": None}, # Stable Cascade
"extra_source_images": {"datatype": list, "default": []}, # Stable Cascade Remix
}

EXTRA_IMAGES_SCHEMA = {
"image": {"datatype": Image.Image, "default": None},
"strength": {"datatype": float, "min": 0.0, "max": 5.0, "default": 1.0},
}

LORA_SCHEMA = {
Expand Down Expand Up @@ -219,6 +225,8 @@ class HordeLib:
"stable_cascade_empty_latent_image.width": "width",
"stable_cascade_empty_latent_image.height": "height",
"stable_cascade_empty_latent_image.batch_size": "n_iter",
"sc_image_loader.image": "source_image",
"sc_image_loader_0.image": "source_image",
"sampler_stage_c.sampler_name": "sampler_name",
"sampler_stage_b.sampler_name": "sampler_name",
"sampler_stage_c.cfg": "cfg_scale",
Expand Down Expand Up @@ -338,6 +346,12 @@ def _validate_data_structure(self, data, schema_definition=PAYLOAD_SCHEMA):
# Remove invalid tis
data["tis"] = [x for x in data["tis"] if x.get("name")]

# Do the same for extra images, if we have them in this data structure
if data.get("extra_source_images"):
for i, img in enumerate(data.get("extra_source_images")):
data["extra_source_images"][i] = self._validate_data_structure(img, HordeLib.EXTRA_IMAGES_SCHEMA)
data["extra_source_images"] = [x for x in data["extra_source_images"] if x.get("image")]

return data

def _apply_aihorde_compatibility_hacks(self, payload: dict) -> tuple[dict, list[GenMetadataEntry]]:
Expand Down Expand Up @@ -802,6 +816,58 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
# the source image instead of the latent noise generator
if pipeline_params.get("image_loader.image"):
self.generator.reconnect_input(pipeline_data, "sampler.latent_image", "vae_encode")
if pipeline_params.get("sc_image_loader.image"):
self.generator.reconnect_input(
pipeline_data,
"sampler_stage_c.latent_image",
"stablecascade_stagec_vaeencode",
)
self.generator.reconnect_input(
pipeline_data,
"sampler_stage_b.latent_image",
"stablecascade_stagec_vaeencode",
)

# If we have a remix request, we check for extra images to add to the pipeline
if payload.get("source_processing") == "remix":
logger.debug([payload.get("source_image"), payload.get("extra_source_images")])
for image_index in range(len(payload.get("extra_source_images", []))):
# The first image is always taken from the source_image param
# That will sit on the 0 spot.
# Therefore we want the extra images to start iterating from 1
extra_image = payload["extra_source_images"][image_index]["image"]
extra_image_strength = payload["extra_source_images"][image_index].get("strength", 1)
node_index = image_index + 1
pipeline_data[f"sc_image_loader_{node_index}"] = {
"inputs": {"image": extra_image, "upload": "image"},
"class_type": "HordeImageLoader",
}
pipeline_data[f"clip_vision_encode_{node_index}"] = {
"inputs": {
"clip_vision": ["model_loader_stage_c", 3],
"image": [f"sc_image_loader_{node_index}", 0],
},
"class_type": "CLIPVisionEncode",
}
pipeline_data[f"unclip_conditioning_{node_index}"] = {
"inputs": {
"strength": extra_image_strength,
"noise_augmentation": 0,
# Each conditioning ingests the conditioning before it like a chain
"conditioning": [f"unclip_conditioning_{node_index-1}", 0],
"clip_vision_output": [f"clip_vision_encode_{node_index}", 0],
},
"class_type": "unCLIPConditioning",
}

# The last extra image always connects to the stage_c sampler positive prompt
if image_index == len(payload.get("extra_source_images")) - 1:
self.generator.reconnect_input(
pipeline_data,
"sampler_stage_c.positive",
f"unclip_conditioning_{node_index}",
)

return pipeline_params, faults

def _get_appropriate_pipeline(self, params):
Expand All @@ -818,11 +884,15 @@ def _get_appropriate_pipeline(self, params):
# controlnet_annotator
# image_facefix
# image_upscale
# stable_cascade
# stable_cascade_remix

# controlnet, controlnet_hires_fix controlnet_annotator
if params.get("model_name"):
model_details = SharedModelManager.manager.compvis.get_model_reference_info(params["model_name"])
if model_details.get("baseline") == "stable_cascade":
if params.get("source_processing") == "remix":
return "stable_cascade_remix"
return "stable_cascade"
if params.get("control_type"):
if params.get("return_control_map", False):
Expand Down
32 changes: 18 additions & 14 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,10 +1153,12 @@ def get_lora_last_use(self, lora_name, is_version: bool = False):
return None
return datetime.strptime(lora["versions"][version]["last_used"], "%Y-%m-%d %H:%M:%S")

def fetch_adhoc_lora(self, lora_name, timeout=45, is_version: bool = False) -> str | None:
"""Checks if a LoRa is available
If not, immediately downloads it
Finally, returns the lora name"""
def fetch_adhoc_lora(self, lora_name, timeout: int | None = 45, is_version: bool = False) -> str | None:
"""Checks if a LoRa is available. If not, waits for it to be downloaded to complete *if timeout is set*.
- If timeout is set, it will wait for the download to complete and return the lora name.
- If timeout is *not* set, it will begin a download thread and return `None` immediately.
"""
if is_version and not isinstance(lora_name, int) and not lora_name.isdigit():
logger.debug("Lora version requested, but lora name is not an integer")
return None
Expand Down Expand Up @@ -1199,16 +1201,18 @@ def fetch_adhoc_lora(self, lora_name, timeout=45, is_version: bool = False) -> s
self._touch_lora(lora_name, False)
return fuzzy_find
self._download_lora(lora)
# We need to wait a bit to make sure the threads pick up the download
time.sleep(self.THREAD_WAIT_TIME)
self.wait_for_downloads(timeout)
version = self.find_latest_version(lora)
if is_version:
version = lora_name
if version is None:
return None
self._touch_lora(version, True)
return lora["name"]
if timeout is not None:
# We need to wait a bit to make sure the threads pick up the download
time.sleep(self.THREAD_WAIT_TIME)
self.wait_for_downloads(timeout)
version = self.find_latest_version(lora)
if is_version:
version = lora_name
if version is None:
return None
self._touch_lora(version, True)
return lora["name"]
return None

def do_baselines_match(self, lora_name, model_details, is_version: bool = False):
self._check_for_refresh(lora_name)
Expand Down
Loading

0 comments on commit 9ee3518

Please sign in to comment.