Skip to content

Commit

Permalink
fixed randomization bug
Browse files Browse the repository at this point in the history
positive and negative prompts had same randomization
  • Loading branch information
SirVeggie committed Apr 1, 2024
1 parent f3dbdb4 commit b308fbe
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions scripts/style_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def is_opening(text, i):
def is_closing(text, i):
list = ['}', ')', ']', '>']
return text[i] in list and (i == 0 or text[i-1] != '\\')
def decode(text: str, hires: bool, seed: int):
def decode(text: str, hires: bool, neg: bool, seed: int):
depth = 0
start = -1
end = -1
mode = "random"
count = 0
splits = []
rand = random.Random(seed)
rand = random.Random(seed + (1 if neg else 0))

if len(text) == 0:
return text
Expand Down Expand Up @@ -83,7 +83,7 @@ def decode(text: str, hires: bool, seed: int):

if end != -1:
if mode == "hr" and len(splits) > 1:
print("Warning: multiple splits in hr mode")
logger.error("Warning: multiple splits in hr mode")
return text

if mode == "hr" and check_feature(extn_hires):
Expand All @@ -94,7 +94,6 @@ def decode(text: str, hires: bool, seed: int):

elif mode == "random" and check_feature(extn_random):
parts = []
print(text[start+1:end])
if len(splits) == 0:
parts.append(text[start+1:end])
else:
Expand Down Expand Up @@ -182,7 +181,7 @@ def rewrite_prompt(prompt: str, neg: bool, hires: bool, seed: int):
depth = 0
previous_prompt = prompt
while depth < 5:
prompt = decode(prompt, hires, seed)
prompt = decode(prompt, hires, neg, seed)

for name in style_names:
if name not in prompt:
Expand Down Expand Up @@ -214,8 +213,6 @@ def rewrite_prompt(prompt: str, neg: bool, hires: bool, seed: int):
# check if we're doing t2i with HR
is_t2i = isinstance(p, StableDiffusionProcessingTxt2Img)
hr_enabled = p.enable_hr if is_t2i else False

# logger.info(f"{extn_name} processing...")

if check_feature(extn_info):
orig_pos_prompt = deepcopy(p.all_prompts[0])
Expand Down Expand Up @@ -250,5 +247,4 @@ def rewrite_prompt(prompt: str, neg: bool, hires: bool, seed: int):

if check_feature(extn_info):
p.extra_generation_params.setdefault(TS_PROMPT, orig_pos_prompt)
p.extra_generation_params.setdefault(TS_NEG, orig_neg_prompt)
# logger.info(f"{extn_name} processing done.")
p.extra_generation_params.setdefault(TS_NEG, orig_neg_prompt)

0 comments on commit b308fbe

Please sign in to comment.