Skip to content

Commit

Permalink
Use clean_text()
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie Joynt committed Nov 8, 2023
1 parent a3e0b65 commit 799259c
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions modules/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def clean_text(text: str):
for _, regex, replace in re_list:
text = regex.sub(replace, text)

return text
return text.strip()


def merge_prompts(style_prompt: str, prompt: str) -> str:
Expand All @@ -47,22 +47,26 @@ def apply_styles_to_prompt(prompt, styles):


def extract_style_text_from_prompt(style_text, prompt):
cleaned = {"style": clean_text(style_text), "prompt": clean_text(prompt)}

if "{prompt}" in cleaned["style"]:
"""
Checks the prompt to see if the style text is wrapped around it. If so,
returns True plus the prompt text without the style text. Otherwise, returns
False with the original prompt.
"""
stripped_prompt = clean_text(prompt)
stripped_style_text = clean_text(style_text)
if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text.
left, right = cleaned["style"].split("{prompt}", 2)
if cleaned["prompt"].startswith(left) and cleaned["prompt"].endswith(right):
prompt = cleaned["prompt"][len(left) : len(cleaned["prompt"]) - len(right)]
if prompt.endswith(", "):
prompt = prompt[:-2]
left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
else:
# Work out whether the prompt ends with the style text. If so, we return
# True and the prompt text up to where the style text starts.
if cleaned["prompt"].endswith(cleaned["style"]):
prompt = cleaned["prompt"][: len(cleaned["prompt"]) - len(cleaned["style"])]
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]

if prompt.endswith(", "):
prompt = prompt[:-2]

Expand All @@ -72,6 +76,11 @@ def extract_style_text_from_prompt(style_text, prompt):


def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
"""
Takes a style and compares it to the prompt and negative prompt. If the style
matches, returns True plus the prompt and negative prompt with the style text
removed. Otherwise, returns False with the original prompt and negative prompt.
"""
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt

Expand Down

0 comments on commit 799259c

Please sign in to comment.