Skip to content

Commit

Permalink
more visible divider in style list
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie Joynt committed Nov 8, 2023
1 parent 799259c commit 65dc026
Showing 1 changed file with 53 additions and 31 deletions.
84 changes: 53 additions & 31 deletions modules/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ class PromptStyle(typing.NamedTuple):
path: str = None


def clean_text(text: str):
def clean_text(text: str) -> str:
"""
Clean up the prompt and style text to make it easier to match against each other.
Iterating through a list of regular expressions and replacement strings, we
clean up the prompt and style text to make it easier to match against each
other.
"""
# A dictionary of regular expressions to tidy up the prompt text
re_list = [
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
("multiple spaces", re.compile("\s{2,}"), " "),
]
for _, regex, replace in re_list:
text = regex.sub(replace, text)

return text.strip()
return text.strip(", ")


def merge_prompts(style_prompt: str, prompt: str) -> str:
Expand All @@ -46,24 +47,27 @@ def apply_styles_to_prompt(prompt, styles):
return clean_text(prompt)


def extract_style_text_from_prompt(style_text, prompt):
def unwrap_style_text_from_prompt(style_text, prompt):
"""
Checks the prompt to see if the style text is wrapped around it. If so,
returns True plus the prompt text without the style text. Otherwise, returns
False with the original prompt.
Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified.
"""
stripped_prompt = clean_text(prompt)
stripped_style_text = clean_text(style_text)
if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text.
# return True and the "inner" prompt text that isn't part of the style.
left, right = stripped_style_text.split("{prompt}", 2)
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
return True, prompt
else:
# Work out whether the prompt ends with the style text. If so, we return
# True and the prompt text up to where the style text starts.
# Work out whether the given prompt ends with the style text. If so, we
# return True and the prompt text up to where the style text starts.
if stripped_prompt.endswith(stripped_style_text):
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]

Expand All @@ -75,7 +79,7 @@ def extract_style_text_from_prompt(style_text, prompt):
return False, prompt


def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
"""
Takes a style and compares it to the prompt and negative prompt. If the style
matches, returns True plus the prompt and negative prompt with the style text
Expand All @@ -84,13 +88,13 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
if not style.prompt and not style.negative_prompt:
return False, prompt, negative_prompt

match_positive, extracted_positive = extract_style_text_from_prompt(
match_positive, extracted_positive = unwrap_style_text_from_prompt(
style.prompt, prompt
)
if not match_positive:
return False, prompt, negative_prompt

match_negative, extracted_negative = extract_style_text_from_prompt(
match_negative, extracted_negative = unwrap_style_text_from_prompt(
style.negative_prompt, negative_prompt
)
if not match_negative:
Expand All @@ -116,20 +120,30 @@ def __init__(self, path: str):
self.reload()

def reload(self):
"""
Clears the style database and reloads the styles from the CSV file(s)
matching the path used to initialize the database.
"""
self.styles.clear()

path, filename = os.path.split(self.path)

# if the filename component of the path contains a wildcard,
# e.g. styles*.csv, load all matching files
if "*" in filename:
fileglob = filename.split("*")[0] + "*.csv"
filelist = []
for file in os.listdir(path):
if fnmatch.fnmatch(file, fileglob):
self.styles[file.upper()] = PromptStyle(
f"{file.upper()}", None, None, "do_not_save"
filelist.append(file)
# Add a visible divider to the style list
divider = f"{'-' * 20} {file.upper()} {'-' * 20}"
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
# Add styles from this CSV file
self.load_from_csv(os.path.join(path, file))
if len(filelist) == 0:
print(f"No styles found in {path} matching {fileglob}")
return
elif not os.path.exists(self.path):
print(f"Style database not found: {self.path}")
return
Expand All @@ -151,6 +165,27 @@ def load_from_csv(self, path: str):
row["name"], prompt, negative_prompt, path
)

def get_style_paths(self) -> list():
"""
Returns a list of all distinct paths, including the default path, of
files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)

# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)

# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")

return list(style_paths)

def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]

Expand All @@ -168,23 +203,10 @@ def apply_negative_styles_to_prompt(self, prompt, styles):
)

def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
# The path argument is deprecated, but kept for compatibility
_ = path

# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)

# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)

# Remove any paths for styles that are just list dividers
style_paths.remove("do_not_save")
style_paths = self.get_style_paths()

csv_names = [os.path.split(path)[1].lower() for path in style_paths]

Expand Down Expand Up @@ -215,7 +237,7 @@ def extract_styles_from_prompt(self, prompt, negative_prompt):
found_style = None

for style in applicable_styles:
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
style, prompt, negative_prompt
)
if is_match:
Expand Down

0 comments on commit 65dc026

Please sign in to comment.