Skip to content

Commit

Permalink
feat: inline lora optimisations (#2967)
Browse files Browse the repository at this point in the history
* feat: add performance loras to the end of the loras array

* fix: resolve circular dependency for unit tests

* feat: allow multiple matches for each token, optimize and extract method cleanup_prompt

* fix: update unit tests

* feat: ignore custom wildcards
  • Loading branch information
mashb1t authored May 20, 2024
1 parent c995511 commit 65a8b25
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 36 deletions.
12 changes: 7 additions & 5 deletions modules/async_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,12 @@ def handler(async_task):

steps = performance_selection.steps()

performance_loras = []

if performance_selection == Performance.EXTREME_SPEED:
print('Enter LCM mode.')
progressbar(async_task, 1, 'Downloading LCM components ...')
loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
performance_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]

if refiner_model_name != 'None':
print(f'Refiner disabled in LCM mode.')
Expand All @@ -259,7 +261,7 @@ def handler(async_task):
elif performance_selection == Performance.LIGHTNING:
print('Enter Lightning mode.')
progressbar(async_task, 1, 'Downloading Lightning components ...')
loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]
performance_loras += [(modules.config.downloading_sdxl_lightning_lora(), 1.0)]

if refiner_model_name != 'None':
print(f'Refiner disabled in Lightning mode.')
Expand All @@ -278,7 +280,7 @@ def handler(async_task):
elif performance_selection == Performance.HYPER_SD:
print('Enter Hyper-SD mode.')
progressbar(async_task, 1, 'Downloading Hyper-SD components ...')
loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]
performance_loras += [(modules.config.downloading_sdxl_hyper_sd_lora(), 0.8)]

if refiner_model_name != 'None':
print(f'Refiner disabled in Hyper-SD mode.')
Expand Down Expand Up @@ -458,8 +460,8 @@ def handler(async_task):

progressbar(async_task, 2, 'Loading models ...')

loras = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)

loras, prompt = parse_lora_references_from_prompt(prompt, loras, modules.config.default_max_lora_number)
loras += performance_loras
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras,
use_synthetic_refiner=use_synthetic_refiner, vae_name=vae_name)
Expand Down
3 changes: 1 addition & 2 deletions modules/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import modules.sdxl_styles

from modules.model_loader import load_file_from_url
from modules.util import makedirs_with_log
from modules.extra_utils import get_files_from_folder
from modules.extra_utils import makedirs_with_log, get_files_from_folder
from modules.flags import OutputFormat, Performance, MetadataScheme


Expand Down
6 changes: 6 additions & 0 deletions modules/extra_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import os

def makedirs_with_log(path):
try:
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')


def get_files_from_folder(folder_path, extensions=None, name_filter=None):
if not os.path.isdir(folder_path):
Expand Down
58 changes: 42 additions & 16 deletions modules/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@

from PIL import Image

import modules.config
import modules.sdxl_styles

LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)


# Regexp compiled once. Matches entries with the following pattern:
# <lora:some_lora:1>
# <lora:aNotherLora:-1.6>
LORAS_PROMPT_PATTERN = re.compile(r".* <lora : ([^:]+) : ([+-]? (?: (?:\d+ (?:\.\d*)?) | (?:\.\d+)))> .*", re.X)
LORAS_PROMPT_PATTERN = re.compile(r"(<lora:([^:]+):([+-]?(?:\d+(?:\.\d*)?|\.\d+))>)", re.X)

HASH_SHA256_LENGTH = 10

Expand Down Expand Up @@ -372,31 +372,57 @@ def get_file_from_folder_list(name, folders):
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))


def makedirs_with_log(path):
try:
os.makedirs(path, exist_ok=True)
except OSError as error:
print(f'Directory {path} could not be created, reason: {error}')
def get_enabled_loras(loras: list, remove_none=True) -> list:
return [(lora[1], lora[2]) for lora in loras if lora[0] and (lora[1] != 'None' if remove_none else True)]


def get_enabled_loras(loras: list) -> list:
return [(lora[1], lora[2]) for lora in loras if lora[0]]
def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5,
prompt_cleanup=True, deduplicate_loras=True) -> tuple[List[Tuple[AnyStr, float]], str]:
found_loras = []
prompt_without_loras = ""
for token in prompt.split(" "):
matches = LORAS_PROMPT_PATTERN.findall(token)

if matches:
for match in matches:
found_loras.append((f"{match[1]}.safetensors", float(match[2])))
prompt_without_loras += token.replace(match[0], '')
else:
prompt_without_loras += token
prompt_without_loras += ' '

cleaned_prompt = prompt_without_loras[:-1]
if prompt_cleanup:
cleaned_prompt = cleanup_prompt(prompt_without_loras)

def parse_lora_references_from_prompt(prompt: str, loras: List[Tuple[AnyStr, float]], loras_limit: int = 5) -> List[Tuple[AnyStr, float]]:
new_loras = []
updated_loras = []
for token in prompt.split(","):
m = LORAS_PROMPT_PATTERN.match(token)
lora_names = [lora[0] for lora in loras]
for found_lora in found_loras:
if deduplicate_loras and found_lora[0] in lora_names:
continue
new_loras.append(found_lora)

if m:
new_loras.append((f"{m.group(1)}.safetensors", float(m.group(2))))
if len(new_loras) == 0:
return loras, cleaned_prompt

updated_loras = []
for lora in loras + new_loras:
if lora[0] != "None":
updated_loras.append(lora)

return updated_loras[:loras_limit]
return updated_loras[:loras_limit], cleaned_prompt


def cleanup_prompt(prompt):
prompt = re.sub(' +', ' ', prompt)
prompt = re.sub(',+', ',', prompt)
cleaned_prompt = ''
for token in prompt.split(','):
token = token.strip()
if token == '':
continue
cleaned_prompt += token + ', '
return cleaned_prompt[:-2]


def apply_wildcards(wildcard_text, rng, i, read_wildcards_in_order) -> str:
Expand Down
41 changes: 28 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ def test_can_parse_tokens_with_lora(self):
test_cases = [
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 5),
"output": [("hey-lora.safetensors", 0.4), ("you-lora.safetensors", 0.2)],
"output": (
[('hey-lora.safetensors', 0.4), ('you-lora.safetensors', 0.2)], 'some prompt, very cool, cool'),
},
# Test can not exceed limit
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4>, cool <lora:you-lora:0.2>", [], 1),
"output": [("hey-lora.safetensors", 0.4)],
"output": (
[('hey-lora.safetensors', 0.4)],
'some prompt, very cool, cool'
),
},
# test Loras from UI take precedence over prompt
{
Expand All @@ -22,22 +26,33 @@ def test_can_parse_tokens_with_lora(self):
[("hey-lora.safetensors", 0.4)],
5,
),
"output": [
("hey-lora.safetensors", 0.4),
("l1.safetensors", 0.4),
("l2.safetensors", -0.2),
("l3.safetensors", 0.3),
("l4.safetensors", 0.5),
],
"output": (
[
('hey-lora.safetensors', 0.4),
('l1.safetensors', 0.4),
('l2.safetensors', -0.2),
('l3.safetensors', 0.3),
('l4.safetensors', 0.5)
],
'some prompt, very cool'
)
},
# Test lora specification not separated by comma are ignored, only latest specified is used
{
"input": ("some prompt, very cool, <lora:hey-lora:0.4><lora:you-lora:0.2>", [], 3),
"output": [("you-lora.safetensors", 0.2)],
"output": (
[
('hey-lora.safetensors', 0.4),
('you-lora.safetensors', 0.2)
],
'some prompt, very cool, <lora:you-lora:0.2><lora:hey-lora:0.4>'
),
},
{
"input": ("<lora:foo:1..2>, <lora:bar:.>, <lora:baz:+> and <lora:quux:>", [], 6),
"output": []
"input": ("<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>", [], 6),
"output": (
[],
'<lora:foo:1..2>, <lora:bar:.>, <test:1.0>, <lora:baz:+> and <lora:quux:>'
)
}
]

Expand Down
8 changes: 8 additions & 0 deletions wildcards/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
*.txt
!animal.txt
!artist.txt
!color.txt
!color_flower.txt
!extended-color.txt
!flower.txt
!nationality.txt

0 comments on commit 65a8b25

Please sign in to comment.