feat: add panorama pipeline for SDXL
ssube committed Sep 10, 2023
1 parent bea5a3c commit 0fa03e7
Showing 8 changed files with 693 additions and 18 deletions.
api/onnx_web/diffusers/load.py: 33 changes (19 additions, 14 deletions)

@@ -24,6 +24,7 @@
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
from .pipelines.pix2pix import OnnxStableDiffusionInstructPix2PixPipeline
from .version_safe_diffusers import (
DDIMScheduler,
@@ -58,6 +59,7 @@
# "inpaint-sdxl": ORTStableDiffusionXLInpaintPipeline,
"lpw": OnnxStableDiffusionLongPromptWeightingPipeline,
"panorama": OnnxStableDiffusionPanoramaPipeline,
"panorama-sdxl": ORTStableDiffusionXLPanoramaPipeline,
"pix2pix": OnnxStableDiffusionInstructPix2PixPipeline,
"txt2img-sdxl": ORTStableDiffusionXLPipeline,
"txt2img": OnnxStableDiffusionPipeline,
@@ -399,7 +401,6 @@ def load_pipeline(
)

# make sure XL models are actually being used
# TODO: why is this needed?
if "text_encoder_session" in components:
logger.info(
"text encoder matches: %s, %s",
@@ -424,23 +425,23 @@
pipe.unet.session == components["unet_session"],
type(pipe.unet),
)
pipe.unet = None
run_gc([device])
pipe.unet = ORTModelUnet(unet_session, unet_model)

if not server.show_progress:
pipe.set_progress_bar_config(disable=True)

optimize_pipeline(server, pipe)

if not params.is_xl():
patch_pipeline(server, pipe, pipeline, pipeline_class, params)
patch_pipeline(server, pipe, pipeline_class, params)

server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])

if not params.is_xl() and hasattr(pipe, "vae_decoder"):
if hasattr(pipe, "vae_decoder"):
pipe.vae_decoder.set_tiled(tiled=params.tiled_vae)

if not params.is_xl() and hasattr(pipe, "vae_encoder"):
if hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_tiled(tiled=params.tiled_vae)

# update panorama params
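The two set_tiled calls above no longer skip XL pipelines, matching the patch_pipeline changes below. Tiled VAE decoding trades speed for memory by decoding the latent in overlapping windows and averaging the overlaps. The following is a rough sketch of that idea, not the actual VAEWrapper from patches/vae.py; window and overlap play the roles of params.tiles and params.overlap, and edge tiles are glossed over:

    import numpy as np

    def decode_tiled(decode, latents, window: int = 32, overlap: float = 0.25):
        # latents: (batch, channels, height, width); the VAE upscales 8x
        batch, _, height, width = latents.shape
        window = min(window, height, width)
        stride = max(1, int(window * (1.0 - overlap)))
        out = np.zeros((batch, 3, height * 8, width * 8), dtype=np.float32)
        weight = np.zeros_like(out)
        for y in range(0, height - window + 1, stride):
            for x in range(0, width - window + 1, stride):
                tile = latents[:, :, y : y + window, x : x + window]
                out[:, :, y * 8 : (y + window) * 8, x * 8 : (x + window) * 8] += decode(tile)
                weight[:, :, y * 8 : (y + window) * 8, x * 8 : (x + window) * 8] += 1.0
        # average overlapping pixels; real implementations blend edges smoothly
        return out / np.maximum(weight, 1.0)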
@@ -514,17 +515,18 @@ def optimize_pipeline(
def patch_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
pipe_type: str,
pipeline: Any,
params: ImageParams,
) -> None:
logger.debug("patching SD pipeline")

if pipe_type != "lpw":
if not params.is_lpw():
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)

original_unet = pipe.unet
pipe.unet = UNetWrapper(server, original_unet)
if not params.is_xl():
original_unet = pipe.unet
pipe.unet = UNetWrapper(server, original_unet)
logger.debug("patched UNet with wrapper")

if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder
@@ -535,6 +537,9 @@ def patch_pipeline(
window=params.tiles,
overlap=params.overlap,
)
logger.debug("patched VAE decoder with wrapper")

if hasattr(pipe, "vae_encoder"):
original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper(
server,
@@ -543,7 +548,7 @@
window=params.tiles,
overlap=params.overlap,
)
elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE
else:
logger.debug("no VAE found to patch")
logger.debug("patched VAE encoder with wrapper")

if hasattr(pipe, "vae"):
logger.warning("not patching single VAE, tiled VAE may not work")
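The new pipelines/panorama_xl.py itself is not shown on this page, but the panorama technique it brings to SDXL follows the same MultiDiffusion-style scheme as the existing OnnxStableDiffusionPanoramaPipeline: each denoising step runs the UNet over overlapping latent windows and averages the overlapping predictions, which is why the UNet and VAE window parameters above matter. A sketch of the usual view computation, following the diffusers panorama pipeline (onnx-web's exact implementation may differ):

    def get_views(height: int, width: int, window_size: int = 64, stride: int = 8):
        # work in latent space: one latent pixel per 8 image pixels
        height //= 8
        width //= 8
        blocks_h = (height - window_size) // stride + 1 if height > window_size else 1
        blocks_w = (width - window_size) // stride + 1 if width > window_size else 1
        views = []
        for i in range(blocks_h * blocks_w):
            h_start = (i // blocks_w) * stride
            w_start = (i % blocks_w) * stride
            views.append((h_start, h_start + window_size, w_start, w_start + window_size))
        return views

    # e.g. get_views(512, 2048) yields 64x64 latent windows stepping 8 columns at a time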
api/onnx_web/diffusers/patches/vae.py: 4 changes (3 additions, 1 deletion)

@@ -39,11 +39,13 @@ def set_window_size(self, window: int, overlap: float):
self.tile_overlap_factor = overlap

def __call__(self, latent_sample=None, sample=None, **kwargs):
model = self.wrapped.model if hasattr(self.wrapped, "model") else self.wrapped.session

# set timestep dtype to input type
sample_dtype = next(
(
input.type
for input in self.wrapped.model.get_inputs()
for input in model.get_inputs()
if input.name == "sample" or input.name == "latent_sample"
),
"tensor(float)",
(diffs for the remaining 6 changed files were not loaded)
