Skip to content

Commit

Permalink
cleanup pulid
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Nov 3, 2024
1 parent 73a1de3 commit 363e7ba
Showing 1 changed file with 64 additions and 20 deletions.
84 changes: 64 additions & 20 deletions scripts/pulid_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
import gradio as gr
import numpy as np
from PIL import Image
from modules import shared, devices, scripts, processing, processing_helpers
from modules import shared, devices, errors, sd_models, scripts, processing, processing_helpers


pulid = None


class Script(scripts.Script):
def __init__(self):
self.images = []
super().__init__()
# self.register() # pulid is script with processing override so xyz doesnt execute

def title(self):
return 'PuLID'

Expand All @@ -24,8 +29,20 @@ def dependencies(self):
install('albumentations==1.4.3', 'albumentations', ignore=False, reinstall=True)
install('pydantic==1.10.15', 'pydantic', ignore=False, reinstall=True)

def register(self): # register xyz grid elements
def apply_field(field):
def fun(p, x, xs): # pylint: disable=unused-argument
setattr(p, field, x)
self.run(p)
return fun

import sys
xyz_classes = [v for k, v in sys.modules.items() if 'xyz_grid_classes' in k][0]
xyz_classes.axis_options.append(xyz_classes.AxisOption("[PuLID] Strength", float, apply_field("pulid_strength")))
xyz_classes.axis_options.append(xyz_classes.AxisOption("[PuLID] Zero", float, apply_field("pulid_zero")))

def load_images(self, files):
init_images = []
self.images = []
for file in files or []:
try:
if isinstance(file, str):
Expand All @@ -39,10 +56,10 @@ def load_images(self, files):
image = Image.open(file.name) # _TemporaryFileWrapper from gr.Files
else:
raise ValueError(f'IP adapter unknown input: {file}')
init_images.append(image)
self.images.append(image)
except Exception as e:
shared.log.warning(f'IP adapter failed to load image: {e}')
return gr.update(value=init_images, visible=len(init_images) > 0)
return gr.update(value=self.images, visible=len(self.images) > 0)

# return signature is array of gradio components
def ui(self, _is_img2img):
Expand All @@ -61,11 +78,13 @@ def ui(self, _is_img2img):
files.change(fn=self.load_images, inputs=[files], outputs=[gallery])
return [strength, zero, sampler, ortho, gallery]

def run(self, p: processing.StableDiffusionProcessing, strength, zero, sampler, ortho, gallery): # pylint: disable=arguments-differ
def run(self, p: processing.StableDiffusionProcessing, strength: float = 0.8, zero: int = 20, sampler: str = 'dpmpp_sde', ortho: str = 'v2', gallery: list = []): # pylint: disable=arguments-differ
global pulid # pylint: disable=global-statement
images = []
try:
images = [Image.open(f['name']) for f in gallery]
if len(gallery) == 0:
gallery = self.images
images = [Image.open(f['name']) for f in gallery if isinstance(f, dict)]
images = [np.array(image) for image in images]
except Exception as e:
shared.log.error(f'PuLID: failed to load images: {e}')
Expand All @@ -79,7 +98,11 @@ def run(self, p: processing.StableDiffusionProcessing, strength, zero, sampler,
return None
if pulid is None:
self.dependencies()
from modules import pulid # pylint: disable=redefined-outer-name
try:
from modules import pulid # pylint: disable=redefined-outer-name
except Exception as e:
shared.log.error(f'PuLID: failed to import library: {e}')
return None
# import os
# import importlib
# module_path = os.path.join(os.path.dirname(__file__), '..', 'pulid', '__init__.py')
Expand All @@ -92,16 +115,29 @@ def run(self, p: processing.StableDiffusionProcessing, strength, zero, sampler,
if p.batch_size > 1:
shared.log.warning('PuLID: batch size not supported')
p.batch_size = 1
strength = getattr(p, 'pulid_strength', strength)
zero = getattr(p, 'pulid_zero', zero)

processing.fix_seed(p)
pipe = None
if shared.sd_model_type == 'sdxl':
pipe = pulid.PuLIDPipelineXL(
pipe =shared.sd_model,
device=devices.device,
sampler=sampler,
cache_dir=shared.opts.hfcache_dir,
)
# TODO pulid has monolithic inference so not really working with offloading
sd_models.move_model(shared.sd_model, devices.device)
sd_models.move_model(shared.sd_model.vae, devices.device)
sd_models.move_model(shared.sd_model.unet, devices.device)
sd_models.move_model(shared.sd_model.text_encoder, devices.device)
sd_models.move_model(shared.sd_model.text_encoder_2, devices.device)
try:
pipe = pulid.PuLIDPipelineXL(
pipe =shared.sd_model,
device=devices.device,
sampler=sampler,
cache_dir=shared.opts.hfcache_dir,
)
except Exception as e:
shared.log.error(f'PuLID: failed to create pipeline: {e}')
errors.display(e, 'PuLID')
return None
if pipe is None:
return None
shared.state.begin('PuLID')
Expand Down Expand Up @@ -134,18 +170,26 @@ def run(self, p: processing.StableDiffusionProcessing, strength, zero, sampler,
with devices.inference_context():
uncond_id_embedding, id_embedding = pipe.get_id_embedding(images)
output = pipe.inference(prompt, (1, p.height, p.width), negative_prompt, id_embedding, uncond_id_embedding, strength, p.cfg_scale, p.steps, seed)[0]
outputs.append(output)
infotexts.append(processing.create_infotext(p))
seeds.append(seed)
prompts.append(prompt)
negative_prompts.append(negative_prompt)
if output is not None:
outputs.append(output)
infotexts.append(processing.create_infotext(p))
seeds.append(seed)
prompts.append(prompt)
negative_prompts.append(negative_prompt)

interim = [Image.fromarray(face) for face in pipe.debug_img_list]
t1 = time.time()
shared.log.debug(f'PuLID: output={output} interim={interim} time={t1-t0:.2f}')

p.extra_generation_params["PuLID"] = f'Strength={strength} Zero={zero} Ortho={ortho}'
processed = processing.Processed(p, outputs, infotexts=infotexts, all_seeds=seeds, all_prompts=prompts, all_negative_prompts=negative_prompts)
if len(outputs) > 0:
p.prompt = prompts[0]
p.negative_prompt = negative_prompts[0]
p.seed = seeds[0]
p.all_prompts = prompts
p.all_negative_prompts = negative_prompts
p.all_seeds = seeds
p.extra_generation_params["PuLID"] = f'Strength={strength} Zero={zero} Ortho={ortho}'
processed = processing.Processed(p, outputs, infotexts=infotexts)

shared.state.end('PuLID')
return processed

0 comments on commit 363e7ba

Please sign in to comment.