
Commit 66820ed: update

vladmandic committed Nov 12, 2024
Signed-off-by: Vladimir Mandic <[email protected]>
Parent: 2635906
Showing 17 changed files with 113 additions and 55 deletions.
modules/face/instantid.py (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ def instant_id(p: processing.StableDiffusionProcessing, app, source_images, stre
processing.process_init(p)
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
orig_prompt_attention = shared.opts.prompt_attention
-shared.opts.data['prompt_attention'] = 'Fixed attention' # otherwise need to deal with class_tokens_mask
+shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
p.task_args['image_embeds'] = face_embeds[0].shape # placeholder
p.task_args['image'] = face_images[0]
p.task_args['controlnet_conditioning_scale'] = float(conditioning)
modules/face/photomaker.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@ def photo_maker(p: processing.StableDiffusionProcessing, input_images, trigger,
shared.sd_model.to(dtype=devices.dtype)

orig_prompt_attention = shared.opts.prompt_attention
-shared.opts.data['prompt_attention'] = 'Fixed attention' # otherwise need to deal with class_tokens_mask
+shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask
p.task_args['input_id_images'] = input_images
p.task_args['start_merge_step'] = int(start * p.steps)
p.task_args['prompt'] = p.all_prompts[0] if p.all_prompts is not None else p.prompt
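Both face modules use the same pattern around this change: save the current prompt_attention option, force it to 'fixed' so that weighted-prompt parsing does not have to handle class_tokens_mask, run the pipeline, then restore the original value. A minimal sketch of that save/override/restore pattern, assuming a dict-backed shared.opts as in the snippets above (the helper name and context-manager shape are illustrative, not part of the commit):

from contextlib import contextmanager

@contextmanager
def fixed_prompt_attention(opts):
    # Temporarily force the 'fixed' prompt parser, then restore the previous value.
    # 'prompt_attention' is the option key seen in the diff; the rest is illustrative.
    orig = opts.data.get('prompt_attention')
    opts.data['prompt_attention'] = 'fixed'  # avoids dealing with class_tokens_mask
    try:
        yield
    finally:
        opts.data['prompt_attention'] = orig

Usage would mirror instant_id and photo_maker: wrap the call that runs the pipeline, so the option is restored even if processing raises.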
modules/processing_args.py (9 changes: 5 additions & 4 deletions)
@@ -107,12 +107,11 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2

debug(f'Diffusers pipeline possible: {possible}')
prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompts(prompts, negative_prompts, prompts_2, negative_prompts_2)
-parser = 'Fixed attention'
steps = kwargs.get("num_inference_steps", None) or len(getattr(p, 'timesteps', ['1']))
clip_skip = kwargs.pop("clip_skip", 1)

# prompt_parser_diffusers.fix_position_ids(model)
-if shared.opts.prompt_attention != 'Fixed attention' and 'Onnx' not in model.__class__.__name__ and (
+parser = 'fixed'
+if shared.opts.prompt_attention != 'fixed' and 'Onnx' not in model.__class__.__name__ and (
'StableDiffusion' in model.__class__.__name__ or
'StableCascade' in model.__class__.__name__ or
'Flux' in model.__class__.__name__
@@ -125,6 +124,8 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
if os.environ.get('SD_PROMPT_DEBUG', None) is not None:
errors.display(e, 'Prompt parser encode')
timer.process.record('encode', reset=False)
+else:
+prompt_parser_diffusers.embedder = None

if 'prompt' in possible:
if 'OmniGen' in model.__class__.__name__:
@@ -156,7 +157,7 @@ def set_pipeline_args(p, model, prompts: list, negative_prompts: list, prompts_2
else:
args['negative_prompt'] = negative_prompts

-if 'clip_skip' in possible and parser == 'Fixed attention':
+if 'clip_skip' in possible and parser == 'fixed':
if clip_skip == 1:
pass # clip_skip = None
else:
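The new else branch matters because the diffusers prompt parser keeps a module-level embedder between calls: when the parser is 'fixed', or the model class is not one of the listed families, the code now clears prompt_parser_diffusers.embedder instead of leaving a stale instance around. A rough sketch of the selection logic, with simplified names and far fewer arguments than the real set_pipeline_args (illustrative only):

def select_prompt_parser(model, opts, prompt_parser_diffusers):
    # Decide whether weighted prompt encoding applies, mirroring the diff above.
    parser = 'fixed'
    supported = any(name in model.__class__.__name__ for name in ('StableDiffusion', 'StableCascade', 'Flux'))
    if opts.prompt_attention != 'fixed' and 'Onnx' not in model.__class__.__name__ and supported:
        parser = opts.prompt_attention  # weighted/scheduled prompts get encoded downstream
    else:
        prompt_parser_diffusers.embedder = None  # drop any cached embedder when using fixed attention
    return parser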
modules/processing_info.py (17 changes: 10 additions & 7 deletions)
@@ -4,6 +4,7 @@
from modules.processing_class import StableDiffusionProcessing


+debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
if not shared.native:
from modules import sd_hijack
else:
@@ -39,27 +40,27 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
ops.reverse()
args = {
# basic
"Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
"Sampler": p.sampler_name if p.sampler_name != 'Default' else None,
"Steps": p.steps,
"Seed": all_seeds[index],
"Sampler": p.sampler_name if p.sampler_name != 'Default' else None,
"Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
"CFG scale": p.cfg_scale if p.cfg_scale > 1.0 else None,
"CFG end": p.cfg_end if p.cfg_end < 1.0 else None,
"Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
"Clip skip": p.clip_skip if p.clip_skip > 1 else None,
"Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None,
"Parser": shared.opts.prompt_attention.split()[0],
"Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''),
"Model hash": getattr(p, 'sd_model_hash', None if (not shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash),
"VAE": (None if not shared.opts.add_model_name_to_info or sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD',
"Seed resize from": None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}",
"Clip skip": p.clip_skip if p.clip_skip > 1 else None,
"Prompt2": p.refiner_prompt if len(p.refiner_prompt) > 0 else None,
"Negative2": p.refiner_negative if len(p.refiner_negative) > 0 else None,
"Styles": "; ".join(p.styles) if p.styles is not None and len(p.styles) > 0 else None,
"Tiling": p.tiling if p.tiling else None,
# sdnext
"Backend": 'Diffusers' if shared.native else 'Original',
"App": 'SD.Next',
"Version": git_commit,
"Backend": 'Diffusers' if shared.native else 'Original',
"Pipeline": 'LDM',
"Parser": shared.opts.prompt_attention.split()[0],
"Comment": comment,
"Operations": '; '.join(ops).replace('"', '') if len(p.ops) > 0 else 'none',
}
@@ -165,7 +166,9 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
if isinstance(v, str):
if len(v) == 0 or v == '0x0':
del args[k]
+debug(f'Infotext: args={args}')
params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in args.items()])
negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
infotext = f"{all_prompts[index]}{negative_prompt_text}\n{params_text}".strip()
+debug(f'Infotext: "{infotext}"')
return infotext
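The two added debug calls bracket the final assembly step: the args dict is stripped of empty values and serialized into the usual "key: value, key: value" infotext line that follows the prompt. A compact, self-contained sketch of that assembly, using a stand-in quote() helper rather than generation_parameters_copypaste.quote (illustrative):

def build_infotext(prompt: str, negative_prompt: str, args: dict) -> str:
    # Illustrative re-creation of the tail of create_infotext.
    def quote(v):
        v = str(v)
        return f'"{v}"' if (',' in v or ':' in v) else v

    args = {k: v for k, v in args.items() if v is not None and str(v) not in ('', '0x0')}
    params_text = ", ".join(k if k == v else f'{k}: {quote(v)}' for k, v in args.items())
    negative_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""
    return f"{prompt}{negative_text}\n{params_text}".strip()

For example, build_infotext("a cat", "", {"Steps": 20, "Sampler": None, "Size": "512x512"}) yields "a cat" followed by "Steps: 20, Size: 512x512" on the next line.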
modules/prompt_parser.py (16 changes: 8 additions & 8 deletions)
@@ -308,11 +308,11 @@ def parse_prompt_attention(text):
res = []
round_brackets = []
square_brackets = []
-if opts.prompt_attention == 'Fixed attention':
+if opts.prompt_attention == 'fixed':
res = [[text, 1.0]]
debug(f'Prompt: parser="{opts.prompt_attention}" {res}')
return res
-elif opts.prompt_attention == 'Compel parser':
+elif opts.prompt_attention == 'compel':
conjunction = Compel.parse_prompt_string(text)
if conjunction is None or conjunction.prompts is None or conjunction.prompts is None or len(conjunction.prompts[0].children) == 0:
return [["", 1.0]]
@@ -321,7 +321,7 @@ def parse_prompt_attention(text):
res.append([frag.text, frag.weight])
debug(f'Prompt: parser="{opts.prompt_attention}" {res}')
return res
-elif opts.prompt_attention == 'A1111 parser':
+elif opts.prompt_attention == 'a1111':
re_attention = re_attention_v1
whitespace = ''
else:
@@ -360,7 +360,7 @@ def multiply_range(start_position, multiplier):
for i, part in enumerate(parts):
if i > 0:
res.append(["BREAK", -1])
-if opts.prompt_attention == 'Full parser':
+if opts.prompt_attention == 'native':
part = re_clean.sub("", part)
part = re_whitespace.sub(" ", part).strip()
if len(part) == 0:
@@ -392,15 +392,15 @@ def multiply_range(start_position, multiplier):
log.info(f'Schedules: {all_schedules}')
for schedule in all_schedules:
log.info(f'Schedule: {schedule[0]}')
-opts.data['prompt_attention'] = 'Fixed attention'
+opts.data['prompt_attention'] = 'fixed'
output_list = parse_prompt_attention(schedule[1])
log.info(f' Fixed: {output_list}')
-opts.data['prompt_attention'] = 'Compel parser'
+opts.data['prompt_attention'] = 'compel'
output_list = parse_prompt_attention(schedule[1])
log.info(f' Compel: {output_list}')
-opts.data['prompt_attention'] = 'A1111 parser'
+opts.data['prompt_attention'] = 'a1111'
output_list = parse_prompt_attention(schedule[1])
log.info(f' A1111: {output_list}')
-opts.data['prompt_attention'] = 'Full parser'
+opts.data['prompt_attention'] = 'native'
log.info = parse_prompt_attention(schedule[1])
log.info(f' Full: {output_list}')
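Most of this commit is a rename of the prompt_attention option values from the long UI labels to short keys. Based on the replacements visible in the diff, the mapping is:

# old option value      -> new option value (as seen in this commit)
PROMPT_ATTENTION_RENAMES = {
    'Fixed attention': 'fixed',
    'Compel parser': 'compel',
    'A1111 parser': 'a1111',
    'Full parser': 'native',
    'xhinker parser': 'xhinker',  # from modules/prompt_parser_diffusers.py below
}

A configuration that still stores one of the old labels would presumably need migrating to the corresponding short key.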
modules/prompt_parser_diffusers.py (41 changes: 27 additions & 14 deletions)
@@ -47,6 +47,7 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
self.prompts = prompts
self.negative_prompts = negative_prompts
self.batchsize = len(self.prompts)
+self.attention = None
self.allsame = self.compare_prompts() # collapses batched prompts to single prompt if possible
self.steps = steps
self.clip_skip = clip_skip
@@ -75,6 +76,10 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
def checkcache(self, p):
if shared.opts.sd_textencoder_cache_size == 0:
return False
+if self.attention != shared.opts.prompt_attention:
+debug(f"Prompt change: parser={shared.opts.prompt_attention}")
+cache.clear()
+return False

def flatten(xss):
return [x for xs in xss for x in xs]
@@ -97,23 +102,22 @@ def flatten(xss):
'positive_pooleds': self.positive_pooleds,
'negative_pooleds': self.negative_pooleds,
}
debug(f"Prompt cache: Adding {key}")
debug(f"Prompt cache: add={key}")
while len(cache) > int(shared.opts.sd_textencoder_cache_size):
cache.popitem(last=False)
if item:
self.__dict__.update(cache[key])
cache.move_to_end(key)
-if self.allsame and len(self.prompt_embeds) < self.batchsize: # If current batch larger than cached
+if self.allsame and len(self.prompt_embeds) < self.batchsize:
self.prompt_embeds = [self.prompt_embeds[0]] * self.batchsize
self.positive_pooleds = [self.positive_pooleds[0]] * self.batchsize
self.negative_prompt_embeds = [self.negative_prompt_embeds[0]] * self.batchsize
self.negative_pooleds = [self.negative_pooleds[0]] * self.batchsize
debug(f"Prompt cache: Retrieving {key}")
debug(f"Prompt cache: get={key}")
return True

def compare_prompts(self):
-same = (self.prompts == [self.prompts[0]] * len(self.prompts) and
-self.negative_prompts == [self.negative_prompts[0]] * len(self.negative_prompts))
+same = (self.prompts == [self.prompts[0]] * len(self.prompts) and self.negative_prompts == [self.negative_prompts[0]] * len(self.negative_prompts))
if same:
self.prompts = [self.prompts[0]]
self.negative_prompts = [self.negative_prompts[0]]
@@ -123,6 +127,7 @@ def prepare_schedule(self, prompt, negative_prompt):
self.positive_schedule, scheduled = get_prompt_schedule(prompt, self.steps)
self.negative_schedule, neg_scheduled = get_prompt_schedule(negative_prompt, self.steps)
self.scheduled_prompt = scheduled or neg_scheduled
debug(f"Prompt schedule: positive={self.positive_schedule} negative={self.negative_schedule} scheduled={scheduled}")

def scheduled_encode(self, pipe, batchidx):
prompt_dict = {} # index cache
@@ -138,20 +143,21 @@ def scheduled_encode(self, pipe, batchidx):
prompt_dict[positive_prompt+negative_prompt] = i

def extend_embeds(self, batchidx, idx): # Extends scheduled prompt via index
-self.prompt_embeds[batchidx].append(self.prompt_embeds[batchidx][idx])
-self.negative_prompt_embeds[batchidx].append(self.negative_prompt_embeds[batchidx][idx])
+if len(self.prompt_embeds[batchidx]) > 0:
+self.prompt_embeds[batchidx].append(self.prompt_embeds[batchidx][idx])
+if len(self.negative_prompt_embeds[batchidx]) > 0:
+self.negative_prompt_embeds[batchidx].append(self.negative_prompt_embeds[batchidx][idx])
if len(self.positive_pooleds[batchidx]) > 0:
self.positive_pooleds[batchidx].append(self.positive_pooleds[batchidx][idx])
if len(self.negative_pooleds[batchidx]) > 0:
self.negative_pooleds[batchidx].append(self.negative_pooleds[batchidx][idx])

def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
if shared.opts.prompt_attention == "xhinker parser" or 'Flux' in pipe.__class__.__name__:
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(
pipe, positive_prompt, negative_prompt, self.clip_skip)
self.attention = shared.opts.prompt_attention
if self.attention == "xhinker" or 'Flux' in pipe.__class__.__name__:
prompt_embed, positive_pooled, negative_embed, negative_pooled = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
else:
-prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(
-pipe, positive_prompt, negative_prompt, self.clip_skip)
+prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
if prompt_embed is not None:
self.prompt_embeds[batchidx].append(prompt_embed)
if negative_embed is not None:
@@ -311,6 +317,7 @@ def get_tokens(msg, prompt):
tokens.append(f'UNK_{i}')
token_count = len(ids) - int(has_bos_token) - int(has_eos_token)
debug(f'Prompt tokenizer: type={msg} tokens={token_count} {tokens}')
+return token_count


def normalize_prompt(pairs: list):
@@ -338,6 +345,12 @@ def get_prompts_with_weights(prompt: str):
if shared.opts.prompt_mean_norm:
texts_and_weights = normalize_prompt(texts_and_weights)
texts, text_weights = zip(*texts_and_weights)
+if debug_enabled:
+all_tokens = 0
+for text in texts:
+tokens = get_tokens('section', text)
+all_tokens += tokens
+debug(f'Prompt tokenizer: parser={shared.opts.prompt_attention} tokens={all_tokens}')
debug(f'Prompt: weights={texts_and_weights} time={(time.time() - t0):.3f}')
return texts, text_weights

@@ -479,7 +492,7 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
# negative prompt has no keywords
embed, ntokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[negatives[i]], fragment_weights_batch=[negative_weights[i]], device=device, should_return_tokens=True)
negative_prompt_embeds.append(embed)
-debug(f'Prompt: unpadded shape={prompt_embeds[0].shape} TE{i+1} ptokens={torch.count_nonzero(ptokens)} ntokens={torch.count_nonzero(ntokens)} time={(time.time() - t0):.3f}')
+debug(f'Prompt: unpadded={prompt_embeds[0].shape} TE{i+1} ptokens={torch.count_nonzero(ptokens)} ntokens={torch.count_nonzero(ntokens)} time={(time.time() - t0):.3f}')
if SD3:
t0 = time.time()
pooled_prompt_embeds.append(embedding_providers[0].get_pooled_embeddings(texts=positives[0] if len(positives[0]) == 1 else [" ".join(positives[0])], device=device))
@@ -488,7 +501,7 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
negative_pooled_prompt_embeds.append(embedding_providers[1].get_pooled_embeddings(texts=negatives[-1] if len(negatives[-1]) == 1 else [" ".join(negatives[-1])], device=device))
pooled_prompt_embeds = torch.cat(pooled_prompt_embeds, dim=-1)
negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds, dim=-1)
-debug(f'Prompt: pooled shape={pooled_prompt_embeds[0].shape} time={(time.time() - t0):.3f}')
+debug(f'Prompt: pooled={pooled_prompt_embeds[0].shape} time={(time.time() - t0):.3f}')
elif prompt_embeds[-1].shape[-1] > 768:
t0 = time.time()
if shared.opts.diffusers_pooled == "weighted":
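The cache changes are the functional core of this file: encoded prompt embeddings are memoized in an OrderedDict, so the encoder now records which parser produced them (self.attention) and flushes the whole cache when shared.opts.prompt_attention changes; otherwise embeddings built by one parser could be served for another. A reduced sketch of that invalidation, with simplified, illustrative signatures:

from collections import OrderedDict

cache = OrderedDict()  # key -> dict of cached embedding lists

class PromptEncoder:
    def __init__(self):
        self.attention = None  # parser that produced the currently cached embeddings

    def checkcache(self, key, opts):
        if opts.sd_textencoder_cache_size == 0:
            return False  # caching disabled
        if self.attention != opts.prompt_attention:
            cache.clear()  # parser changed: every cached embedding is stale
            return False
        item = cache.get(key)
        if item is None:
            return False
        self.__dict__.update(item)  # reuse cached embeddings
        cache.move_to_end(key)      # LRU bump
        return True

    def store(self, key, item, opts):
        self.attention = opts.prompt_attention
        cache[key] = item
        while len(cache) > int(opts.sd_textencoder_cache_size):
            cache.popitem(last=False)  # evict the oldest entry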
modules/prompt_parser_xhinker.py (2 changes: 1 addition & 1 deletion)
@@ -1305,7 +1305,7 @@ def get_weighted_text_embeddings_sd3(
# ---------------------- get neg t5 embeddings -------------------------
neg_prompt_tokens_3 = torch.tensor([neg_prompt_tokens_3], dtype=torch.long)

-t5_neg_prompt_embeds = pipe.text_encoder_3(neg_prompt_tokens_3.to(pipe.pipe.text_encoder_3.device))[0].squeeze(0)
+t5_neg_prompt_embeds = pipe.text_encoder_3(neg_prompt_tokens_3.to(pipe.text_encoder_3.device))[0].squeeze(0)
t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(device=pipe.text_encoder_3.device)

# add weight to neg t5 embeddings
(diffs for the remaining changed files did not load and are not shown here)