Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: inject negative embeddings correctly #70

Merged
merged 5 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def _apply_aihorde_compatibility_hacks(self, payload):
return payload

def _final_pipeline_adjustments(self, payload, pipeline_data):

payload = deepcopy(payload)

# Process dynamic prompts
Expand Down Expand Up @@ -308,23 +307,31 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
continue
ti_inject = ti.get("inject_ti")
ti_strength = ti.get("strength", 1.0)
try:
ti_strength = float(ti_strength)
except (TypeError, ValueError):
ti_strength = 1.0
if type(ti_strength) not in [float, int]:
ti_strength = 1.0
ti_id = SharedModelManager.manager.ti.get_ti_id(str(ti["name"]))
if ti_inject == "prompt":
payload["prompt"] = f'(embedding:{ti_id}:{ti_strength}),{payload["prompt"]}'
elif ti_inject == "negprompt":
if "###" not in payload["prompt"]:
payload["prompt"] += "###"
payload["prompt"] = f'{payload["prompt"]},(embedding:{ti_id}:{ti_strength})'
SharedModelManager.manager.ti.touch_ti(ti_name)
# create negative prompt if empty
if "negative_prompt" not in payload:
payload["negative_prompt"] = ""

had_leading_comma = payload["negative_prompt"].startswith(",")

payload["negative_prompt"] = f'{payload["negative_prompt"]},(embedding:{ti_id}:{ti_strength})'
if not had_leading_comma:
payload["negative_prompt"] = payload["negative_prompt"].lstrip(",")
# Setup controlnet if required
# For LORAs we completely build the LORA section of the pipeline dynamically, as we have
# to handle n LORA models which form chained nodes in the pipeline.
# Note that we build this between several nodes, the model_loader, clip_skip and the sampler,
# plus the upscale sampler (used in hires fix) if there is one
if payload.get("loras") and SharedModelManager.manager.lora:

# Remove any requested LORAs that we don't have
valid_loras = []
for lora in payload.get("loras"):
Expand Down Expand Up @@ -366,7 +373,6 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
valid_loras.append(lora)
payload["loras"] = valid_loras
for lora_index, lora in enumerate(payload.get("loras")):

# Inject a lora node (first lora)
if lora_index == 0:
pipeline_data[f"lora_{lora_index}"] = {
Expand Down Expand Up @@ -395,7 +401,6 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
}

for lora_index, lora in enumerate(payload.get("loras")):

# The first LORA always connects to the model loader
if lora_index == 0:
self.generator.reconnect_input(pipeline_data, "lora_0.model", "model_loader")
Expand Down Expand Up @@ -545,7 +550,7 @@ def unlock_models(self, models):
self.generator.unlock_models(models)
logger.debug(f"Unlocked models {','.join(models)}")

def basic_inference(self, payload, rawpng=False):
def _get_validated_payload_and_pipeline_data(self, payload) -> tuple[dict, dict]:
# AIHorde hacks to payload
payload = self._apply_aihorde_compatibility_hacks(payload)
# Check payload types/values and normalise it's format
Expand All @@ -557,6 +562,10 @@ def basic_inference(self, payload, rawpng=False):
# Final adjustments to the pipeline
pipeline_data = self.generator.get_pipeline_data(pipeline)
payload = self._final_pipeline_adjustments(payload, pipeline_data)
return payload, pipeline_data

def basic_inference(self, payload, rawpng=False):
payload, pipeline_data = self._get_validated_payload_and_pipeline_data(payload)
models: list[str] = []
# Run the pipeline
try:
Expand Down
Binary file added images_expected/ti_bad_inject.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images_expected/ti_basic.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images_expected/ti_inject.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
87 changes: 57 additions & 30 deletions tests/test_horde_ti.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
# test_horde_ti.py
import os
from datetime import datetime, timedelta
from pathlib import Path

import pytest
from PIL import Image

from hordelib.horde import HordeLib
from hordelib.shared_model_manager import SharedModelManager

from .testing_shared_functions import check_single_lora_image_similarity


class TestHordeTI:
def test_basic_ti(
@pytest.fixture(scope="class")
def basic_ti_payload_data(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
):
assert shared_model_manager.manager.ti

data = {
) -> dict:
return {
"sampler_name": "k_euler",
"cfg_scale": 8.0,
"denoising_strength": 1.0,
Expand All @@ -33,9 +30,9 @@ def test_basic_ti(
"control_type": None,
"image_is_control": False,
"return_control_map": False,
"prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, "
"atmospheric lighting, embedding:7523###(embedding:7808:0.5), embedding:64870",
"prompt": "(embedding:7523:1.0),Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, (holding a bundle of flowers:1.2), detailed background, light rays, "
"atmospheric lighting###(embedding:7808:0.5),(embedding:64870:1.0)",
"tis": [
{"name": 7523},
{"name": 7808},
Expand All @@ -46,7 +43,15 @@ def test_basic_ti(
"model": stable_diffusion_model_name_for_testing,
}

pil_image = hordelib_instance.basic_inference(data)
def test_basic_ti(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
basic_ti_payload_data,
):
assert shared_model_manager.manager.ti

pil_image = hordelib_instance.basic_inference(basic_ti_payload_data)
assert pil_image is not None
assert (
Path(os.path.join(shared_model_manager.manager.ti.modelFolderPath, "64870.safetensors")).exists() is True
Expand All @@ -57,9 +62,9 @@ def test_basic_ti(

def test_inject_ti(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
basic_ti_payload_data: dict,
):
data = {
"sampler_name": "k_euler",
Expand All @@ -76,32 +81,44 @@ def test_inject_ti(
"image_is_control": False,
"return_control_map": False,
"prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, "
"walking in a field of flowers, (holding a bundle of flowers:1.2), detailed background, light rays, "
"atmospheric lighting",
"tis": [
{"name": 7523, "inject_ti": "prompt", "strength": 0.5},
{"name": 7523, "inject_ti": "prompt", "strength": 1.0},
{"name": 7808, "inject_ti": "negprompt", "strength": 0.5},
{"name": 64870, "inject_ti": "negprompt", "strength": 0.5},
{"name": 64870, "inject_ti": "negprompt", "strength": 1.0},
],
"ddim_steps": 20,
"n_iter": 1,
"model": stable_diffusion_model_name_for_testing,
}

payload, _ = hordelib_instance._get_validated_payload_and_pipeline_data(data)

basic_payload, _ = hordelib_instance._get_validated_payload_and_pipeline_data(
basic_ti_payload_data,
)

assert payload["prompt.text"] == basic_payload["prompt.text"]
assert payload["negative_prompt.text"] == basic_payload["negative_prompt.text"]

assert "(embedding:7523:1.0)" in payload["prompt.text"]
assert "(embedding:7808:0.5)" in payload["negative_prompt.text"]
assert "(embedding:64870:1.0)" in payload["negative_prompt.text"]

pil_image = hordelib_instance.basic_inference(data)
assert pil_image is not None

img_filename = "ti_inject.png"
pil_image.save(f"images/{img_filename}", quality=100)

# assert check_single_lora_image_similarity(
# f"images_expected/{img_filename}",
# pil_image,
# )
assert check_single_lora_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)

def test_bad_inject_ti(
self,
shared_model_manager: type[SharedModelManager],
hordelib_instance: HordeLib,
stable_diffusion_model_name_for_testing: str,
):
Expand All @@ -120,22 +137,32 @@ def test_bad_inject_ti(
"image_is_control": False,
"return_control_map": False,
"prompt": "Closeup portrait of a Lesotho teenage girl wearing a Seanamarena blanket, "
"walking in a field of flowers, holding a bundle of flowers, detailed background, light rays, "
"walking in a field of flowers, (holding a bundle of flowers:1.2), detailed background, light rays, "
"atmospheric lighting",
"tis": [
{"name": 7523, "inject_ti": "prompt", "strength": "0.5"},
{"name": 7808, "inject_ti": "negprompt", "strength": None},
{"name": 64870, "inject_ti": "YOLO", "strength": "YOLO"},
{"name": 7523, "inject_ti": "prompt", "strength": None},
{"name": 7808, "inject_ti": "negprompt", "strength": "0.5"},
{"name": 64870, "inject_ti": "negprompt", "strength": "1.0"},
{"name": 4629, "inject_ti": "YOLO", "strength": "YOLO"},
],
"ddim_steps": 20,
"n_iter": 1,
"model": stable_diffusion_model_name_for_testing,
}

payload, _ = hordelib_instance._get_validated_payload_and_pipeline_data(data)

assert "(embedding:7523:1.0)" in payload["prompt.text"]
assert "(embedding:7808:0.5)" in payload["negative_prompt.text"]
assert "(embedding:64870:1.0)" in payload["negative_prompt.text"]

pil_image = hordelib_instance.basic_inference(data)
assert pil_image is not None

# assert check_single_lora_image_similarity(
# f"images_expected/{img_filename}",
# pil_image,
# )
img_filename = "ti_bad_inject.png"
pil_image.save(f"images/{img_filename}", quality=100)

assert check_single_lora_image_similarity(
f"images_expected/{img_filename}",
pil_image,
)