Skip to content

Commit

Permalink
Ensure default styles file is loaded (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
MisterSeajay authored Jan 19, 2024
1 parent cb5b335 commit 8126540
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions modules/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,30 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
return True, extracted_positive, extracted_negative


def _format_divider(file: str) -> str:
"""
Creates a divider for the style list.
"""
half_len = round(len(file) / 2)
divider = f"{'-' * (20 - half_len)} {file.upper()}"
divider = f"{divider} {'-' * (40 - len(divider))}"
return divider


class StyleDatabase:
def __init__(self, path: str):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

# The default path will be self.path with any wildcard removed. If it
# doesn't exist, the reload() method updates this to be 'styles.csv'.
self.default_file = "styles.csv"
folder, file = os.path.split(self.path)
filename, _, ext = file.partition('*')
self.default_path = os.path.join(folder, filename + ext)

self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

self.reload()

def reload(self):
Expand All @@ -108,17 +120,27 @@ def reload(self):
if fnmatch.fnmatch(file, fileglob):
filelist.append(file)
# Add a visible divider to the style list
half_len = round(len(file) / 2)
divider = f"{'-' * (20 - half_len)} {file.upper()}"
divider = f"{divider} {'-' * (40 - len(divider))}"
divider = _format_divider(file)
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
# Add styles from this CSV file
self.load_from_csv(os.path.join(path, file))

# Ensure the default file is loaded, else its contents may be lost:
if os.path.split(self.default_path)[1] not in filelist:
self.default_path = os.path.join(path, self.default_file)
divider = _format_divider(self.default_file)
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
self.load_from_csv(os.path.join(path, self.default_file))

if len(filelist) == 0:
print(f"No styles found in {path} matching {fileglob}")
self.load_from_csv(self.default_path)
return

elif not os.path.exists(self.path):
print(f"Style database not found: {self.path}")
return
Expand Down

0 comments on commit 8126540

Please sign in to comment.