diff --git a/modules/styles.py b/modules/styles.py index 62e31218a94..503d1feb89e 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -26,7 +26,7 @@ def clean_text(text: str): for _, regex, replace in re_list: text = regex.sub(replace, text) - return text + return text.strip() def merge_prompts(style_prompt: str, prompt: str) -> str: @@ -47,22 +47,26 @@ def apply_styles_to_prompt(prompt, styles): def extract_style_text_from_prompt(style_text, prompt): - cleaned = {"style": clean_text(style_text), "prompt": clean_text(prompt)} - - if "{prompt}" in cleaned["style"]: + """ + Checks the prompt to see if the style text is wrapped around it. If so, + returns True plus the prompt text without the style text. Otherwise, returns + False with the original prompt. + """ + stripped_prompt = clean_text(prompt) + stripped_style_text = clean_text(style_text) + if "{prompt}" in stripped_style_text: # Work out whether the prompt is wrapped in the style text. If so, we # return True and the "inner" prompt text. - left, right = cleaned["style"].split("{prompt}", 2) - if cleaned["prompt"].startswith(left) and cleaned["prompt"].endswith(right): - prompt = cleaned["prompt"][len(left) : len(cleaned["prompt"]) - len(right)] - if prompt.endswith(", "): - prompt = prompt[:-2] + left, right = stripped_style_text.split("{prompt}", 2) + if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): + prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)] return True, prompt else: # Work out whether the prompt ends with the style text. If so, we return # True and the prompt text up to where the style text starts. - if cleaned["prompt"].endswith(cleaned["style"]): - prompt = cleaned["prompt"][: len(cleaned["prompt"]) - len(cleaned["style"])] + if stripped_prompt.endswith(stripped_style_text): + prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)] + if prompt.endswith(", "): prompt = prompt[:-2] @@ -72,6 +76,11 @@ def extract_style_text_from_prompt(style_text, prompt): def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): + """ + Takes a style and compares it to the prompt and negative prompt. If the style + matches, returns True plus the prompt and negative prompt with the style text + removed. Otherwise, returns False with the original prompt and negative prompt. + """ if not style.prompt and not style.negative_prompt: return False, prompt, negative_prompt