Commit: pulid with refine pass

Signed-off-by: Vladimir Mandic <[email protected]>
vladmandic committed Nov 13, 2024
1 parent c77370e commit a0d55a5
Showing 4 changed files with 53 additions and 35 deletions.
6 changes: 6 additions & 0 deletions modules/processing_vae.py
@@ -35,6 +35,8 @@ def create_latents(image, p, dtype=None, device=None):
 
 def full_vae_decode(latents, model):
     t0 = time.time()
+    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
+        model = model.pipe
     if model is None or not hasattr(model, 'vae'):
         shared.log.error('VAE not found in model')
         return []
@@ -148,6 +150,8 @@ def taesd_vae_encode(image):
 def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None):
     t0 = time.time()
     model = model or shared.sd_model
+    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
+        model = model.pipe
     if latents is None or not torch.is_tensor(latents): # already decoded
         return latents
     prev_job = shared.state.job
@@ -196,6 +200,8 @@ def vae_decode(latents, model, output_type='np', full_quality=True, width=None, height=None):
 def vae_encode(image, model, full_quality=True): # pylint: disable=unused-variable
     if shared.state.interrupted or shared.state.skipped:
         return []
+    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):
+        model = model.pipe
     if not hasattr(model, 'vae'):
         shared.log.error('VAE not found in model')
         return []
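All three hunks above add the same guard: when the object handed to the VAE helpers does not expose a vae but does carry an inner pipe (as the PuLID wrapper touched by this commit does), the helpers fall back to that inner pipeline before reporting an error. A minimal sketch of the pattern, with Wrapper and InnerPipe as illustrative stand-ins rather than real classes:

# Sketch only: Wrapper/InnerPipe are hypothetical stand-ins for a pipeline that
# wraps an inner diffusers pipeline; the real helpers also log and return [].
class InnerPipe:
    vae = object()  # stand-in for the actual autoencoder

class Wrapper:
    def __init__(self, pipe):
        self.pipe = pipe  # inner pipeline that actually owns the VAE

def resolve_vae_owner(model):
    if not hasattr(model, 'vae') and hasattr(model, 'pipe'):  # same fallback as above
        model = model.pipe
    if model is None or not hasattr(model, 'vae'):
        return None  # caller would log 'VAE not found in model'
    return model

print(resolve_vae_owner(Wrapper(InnerPipe())) is not None)  # True: the inner pipe is used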
76 changes: 42 additions & 34 deletions modules/prompt_parser_diffusers.py
@@ -19,20 +19,25 @@
 embedder = None
 
 
-def prompt_compatible():
+def prompt_compatible(pipe = None):
+    pipe = pipe or shared.sd_model
     if (
-        'StableDiffusion' not in shared.sd_model.__class__.__name__ and
-        'DemoFusion' not in shared.sd_model.__class__.__name__ and
-        'StableCascade' not in shared.sd_model.__class__.__name__ and
-        'Flux' not in shared.sd_model.__class__.__name__
+        'StableDiffusion' not in pipe.__class__.__name__ and
+        'DemoFusion' not in pipe.__class__.__name__ and
+        'StableCascade' not in pipe.__class__.__name__ and
+        'Flux' not in pipe.__class__.__name__
     ):
-        shared.log.warning(f"Prompt parser not supported: {shared.sd_model.__class__.__name__}")
+        shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
         return False
     return True
 
 
-def prepare_model():
-    pipe = shared.sd_model
+def prepare_model(pipe = None):
+    pipe = pipe or shared.sd_model
+    if not hasattr(pipe, "text_encoder") and hasattr(shared.sd_model, "pipe"):
+        pipe = pipe.pipe
+    if not hasattr(pipe, "text_encoder"):
+        return None
     if shared.opts.diffusers_offload_mode == "balanced":
         pipe = sd_models.apply_balanced_offload(pipe)
     elif hasattr(pipe, "maybe_free_model_hooks"):
@@ -62,7 +67,10 @@ def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
         earlyout = self.checkcache(p)
         if earlyout:
             return
-        pipe = prepare_model()
+        pipe = prepare_model(p.sd_model)
+        if pipe is None:
+            shared.log.error("Prompt encode: cannot find text encoder in model")
+            return
         # per prompt in batch
         for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts)):
            self.prepare_schedule(prompt, negative_prompt)
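prepare_model can now come back empty, so the caller above bails out with an error instead of assuming shared.sd_model always carries a text encoder. A self-contained sketch of the expected resolution behavior; note it simplifies the commit's hasattr(shared.sd_model, "pipe") check to a check on the passed object, and the SimpleNamespace pipelines are purely illustrative:

from types import SimpleNamespace

def resolve_text_encoder(pipe):
    # unwrap a wrapper that hides the inner pipeline, then require a text_encoder
    if not hasattr(pipe, 'text_encoder') and hasattr(pipe, 'pipe'):
        pipe = pipe.pipe
    return pipe if hasattr(pipe, 'text_encoder') else None

inner = SimpleNamespace(text_encoder=object())
wrapper = SimpleNamespace(pipe=inner)           # e.g. the PuLID wrapper around SDXL
print(resolve_text_encoder(wrapper) is inner)   # True: encoding runs on the inner pipe
print(resolve_text_encoder(SimpleNamespace()))  # None: caller logs an error and returns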
@@ -168,8 +176,8 @@ def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
             self.negative_pooleds[batchidx].append(negative_pooled)
 
         if debug_enabled:
-            get_tokens('positive', positive_prompt)
-            get_tokens('negative', negative_prompt)
+            get_tokens(pipe, 'positive', positive_prompt)
+            get_tokens(pipe, 'negative', negative_prompt)
         pipe = prepare_model()
 
     def __call__(self, key, step=0):
@@ -288,25 +296,25 @@ def get_prompt_schedule(prompt, steps):
     return temp, len(schedule) > 1
 
 
-def get_tokens(msg, prompt):
+def get_tokens(pipe, msg, prompt):
     global token_dict, token_type # pylint: disable=global-statement
     if not shared.native:
         return 0
-    if shared.sd_loaded and hasattr(shared.sd_model, 'tokenizer') and shared.sd_model.tokenizer is not None:
+    if shared.sd_loaded and hasattr(pipe, 'tokenizer') and pipe.tokenizer is not None:
        if token_dict is None or token_type != shared.sd_model_type:
            token_type = shared.sd_model_type
-           fn = shared.sd_model.tokenizer.name_or_path
+           fn = pipe.tokenizer.name_or_path
            if fn.endswith('tokenizer'):
-               fn = os.path.join(shared.sd_model.tokenizer.name_or_path, 'vocab.json')
+               fn = os.path.join(pipe.tokenizer.name_or_path, 'vocab.json')
            else:
-               fn = os.path.join(shared.sd_model.tokenizer.name_or_path, 'tokenizer', 'vocab.json')
+               fn = os.path.join(pipe.tokenizer.name_or_path, 'tokenizer', 'vocab.json')
            token_dict = shared.readfile(fn, silent=True)
-           for k, v in shared.sd_model.tokenizer.added_tokens_decoder.items():
+           for k, v in pipe.tokenizer.added_tokens_decoder.items():
                token_dict[str(v)] = k
            shared.log.debug(f'Tokenizer: words={len(token_dict)} file="{fn}"')
-       has_bos_token = shared.sd_model.tokenizer.bos_token_id is not None
-       has_eos_token = shared.sd_model.tokenizer.eos_token_id is not None
-       ids = shared.sd_model.tokenizer(prompt)
+       has_bos_token = pipe.tokenizer.bos_token_id is not None
+       has_eos_token = pipe.tokenizer.eos_token_id is not None
+       ids = pipe.tokenizer(prompt)
        ids = getattr(ids, 'input_ids', [])
        tokens = []
        for i in ids:
@@ -337,18 +345,18 @@ def normalize_prompt(pairs: list):
     return pairs
 
 
-def get_prompts_with_weights(prompt: str):
+def get_prompts_with_weights(pipe, prompt: str):
     t0 = time.time()
-    manager = DiffusersTextualInversionManager(shared.sd_model, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
-    prompt = manager.maybe_convert_prompt(prompt, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
+    manager = DiffusersTextualInversionManager(pipe, pipe.tokenizer or pipe.tokenizer_2)
+    prompt = manager.maybe_convert_prompt(prompt, pipe.tokenizer or pipe.tokenizer_2)
     texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
     if shared.opts.prompt_mean_norm:
         texts_and_weights = normalize_prompt(texts_and_weights)
     texts, text_weights = zip(*texts_and_weights)
     if debug_enabled:
         all_tokens = 0
         for text in texts:
-            tokens = get_tokens('section', text)
+            tokens = get_tokens(pipe, 'section', text)
             all_tokens += tokens
         debug(f'Prompt tokenizer: parser={shared.opts.prompt_attention} tokens={all_tokens}')
     debug(f'Prompt: weights={texts_and_weights} time={(time.time() - t0):.3f}')
@@ -412,7 +420,7 @@ def pad_to_same_length(pipe, embeds, empty_embedding_providers=None):
     return embeds
 
 
-def split_prompts(prompt, SD3 = False):
+def split_prompts(pipe, prompt, SD3 = False):
     if prompt.find("TE2:") != -1:
         prompt, prompt2 = prompt.split("TE2:")
     else:
@@ -430,23 +438,23 @@ def split_prompts(prompt, SD3 = False):
     prompt3 = " " if prompt3.strip() == "" else prompt3.strip()
 
     if SD3 and prompt3 != " ":
-        ps, _ws = get_prompts_with_weights(prompt3)
+        ps, _ws = get_prompts_with_weights(pipe, prompt3)
         prompt3 = " ".join(ps)
     return prompt, prompt2, prompt3
 
 
 def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
     device = devices.device
     SD3 = hasattr(pipe, 'text_encoder_3')
-    prompt, prompt_2, prompt_3 = split_prompts(prompt, SD3)
-    neg_prompt, neg_prompt_2, neg_prompt_3 = split_prompts(neg_prompt, SD3)
+    prompt, prompt_2, prompt_3 = split_prompts(pipe, prompt, SD3)
+    neg_prompt, neg_prompt_2, neg_prompt_3 = split_prompts(pipe, neg_prompt, SD3)
 
     if prompt != prompt_2:
-        ps = [get_prompts_with_weights(p) for p in [prompt, prompt_2]]
-        ns = [get_prompts_with_weights(p) for p in [neg_prompt, neg_prompt_2]]
+        ps = [get_prompts_with_weights(pipe, p) for p in [prompt, prompt_2]]
+        ns = [get_prompts_with_weights(pipe, p) for p in [neg_prompt, neg_prompt_2]]
     else:
-        ps = 2 * [get_prompts_with_weights(prompt)]
-        ns = 2 * [get_prompts_with_weights(neg_prompt)]
+        ps = 2 * [get_prompts_with_weights(pipe, prompt)]
+        ns = 2 * [get_prompts_with_weights(pipe, neg_prompt)]
 
     positives, positive_weights = zip(*ps)
     negatives, negative_weights = zip(*ns)
@@ -561,8 +569,8 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
 
 def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
     is_sd3 = hasattr(pipe, 'text_encoder_3')
-    prompt, prompt_2, _prompt_3 = split_prompts(prompt, is_sd3)
-    neg_prompt, neg_prompt_2, _neg_prompt_3 = split_prompts(neg_prompt, is_sd3)
+    prompt, prompt_2, _prompt_3 = split_prompts(pipe, prompt, is_sd3)
+    neg_prompt, neg_prompt_2, _neg_prompt_3 = split_prompts(pipe, neg_prompt, is_sd3)
     try:
         prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer)
         neg_prompt = pipe.maybe_convert_prompt(neg_prompt, pipe.tokenizer)
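Taken together, the remaining hunks in this file are mechanical: get_tokens, get_prompts_with_weights and split_prompts gain an explicit pipe parameter, and every call site in get_weighted_text_embeddings and get_xhinker_text_embeddings threads the same pipeline through instead of reaching for the global shared.sd_model. A tiny, self-contained sketch of that explicit-pipe convention (FakeTokenizer, FakePipe and count_tokens are illustrative names, not the real API):

class FakeTokenizer:
    def __call__(self, prompt):
        return prompt.split()  # crude stand-in for real tokenization

class FakePipe:
    def __init__(self):
        self.tokenizer = FakeTokenizer()

def count_tokens(pipe, msg, prompt):
    # the pipeline is an argument, mirroring get_tokens(pipe, msg, prompt)
    return len(pipe.tokenizer(prompt))

refine_pipe = FakePipe()  # e.g. the inner pipeline used by a PuLID refine pass
print(count_tokens(refine_pipe, 'positive', 'portrait photo, soft light'))  # 4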
4 changes: 4 additions & 0 deletions modules/pulid/pulid_sdxl.py
@@ -4,6 +4,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from PIL import Image
 from diffusers import StableDiffusionXLPipeline
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -353,6 +354,9 @@ def __call__(
         debug(f'PulID call: width={width} height={height} cfg={guidance_scale} steps={num_inference_steps} seed={seed} strength={strength} id_scale={id_scale} output={output_type}')
         self.step = 0 # pylint: disable=attribute-defined-outside-init
         self.callback_on_step_end = callback_on_step_end # pylint: disable=attribute-defined-outside-init
+        if isinstance(image, list) and len(image) > 0 and isinstance(image[0], Image.Image):
+            if image[0].width != width or image[0].height != height: # override width/height if different
+                width, height = image[0].width, image[0].height
         size = (1, height, width)
         # sigmas
         sigmas = self.get_sigmas_karras(num_inference_steps).to(self.device)
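The PuLID pipeline itself gains a small convenience: if the image argument is a non-empty list of PIL images whose dimensions differ from the requested width/height, the call adopts the image dimensions before building the latent size tuple. A minimal sketch of just that override, assuming PIL is available:

from PIL import Image

def resolve_size(image, width, height):
    # mirrors the check added to __call__ above
    if isinstance(image, list) and len(image) > 0 and isinstance(image[0], Image.Image):
        if image[0].width != width or image[0].height != height:  # override if different
            width, height = image[0].width, image[0].height
    return (1, height, width)  # 'size' tuple as computed right after the new lines

print(resolve_size([Image.new('RGB', (1152, 896))], 1024, 1024))  # (1, 896, 1152)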
2 changes: 1 addition & 1 deletion wiki
Submodule wiki updated from 352fc6 to 96f28b
