feat(api): add tiled VAE wrapper
ssube committed Apr 28, 2023
1 parent eef85aa commit 64a753e
Showing 6 changed files with 251 additions and 97 deletions.
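The main functional change: the UNetWrapper previously defined inline in load.py moves into a new diffusers/patches package, and a new VAEWrapper runs the ONNX VAE decoder (and encoder) over overlapping tiles that are blended back together, so memory use stays roughly constant as image size grows. A minimal sketch of where the wrapper attaches, assuming only a pipeline object that exposes an OnnxRuntimeModel at vae_decoder (the real wiring is patch_pipeline in load.py, shown below):

# Sketch only, not the commit's exact wiring; see patch_pipeline below.
# `server` is the API's ServerContext, `pipe` any ONNX pipeline with a vae_decoder.
from onnx_web.diffusers.patches.vae import VAEWrapper

def wrap_vae_decoder(server, pipe):
    # Swap the plain decoder for the tiled wrapper; other attribute lookups
    # still fall through to the wrapped model via VAEWrapper.__getattr__.
    pipe.vae_decoder = VAEWrapper(server, pipe.vae_decoder, decoder=True)
    return pipe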
90 changes: 3 additions & 87 deletions api/onnx_web/diffusers/load.py
@@ -15,6 +15,8 @@
from ..params import DeviceParams
from ..server import ServerContext
from ..utils import run_gc
from .patches.unet import UNetWrapper
from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
@@ -397,92 +399,6 @@ def optimize_pipeline(
logger.warning("error while enabling memory efficient attention: %s", e)


# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32


class UNetWrapper(object):
prompt_embeds: Optional[List[np.ndarray]] = None
prompt_index: int = 0
server: ServerContext
wrapped: OnnxRuntimeModel

def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
):
self.server = server
self.wrapped = wrapped

def __call__(
self,
sample: np.ndarray = None,
timestep: np.ndarray = None,
encoder_hidden_states: np.ndarray = None,
**kwargs,
):
global timestep_dtype
timestep_dtype = timestep.dtype

logger.trace(
"UNet parameter types: %s, %s, %s",
sample.dtype,
timestep.dtype,
encoder_hidden_states.dtype,
)

if self.prompt_embeds is not None:
step_index = self.prompt_index % len(self.prompt_embeds)
logger.trace("multiple prompt embeds found, using step: %s", step_index)
encoder_hidden_states = self.prompt_embeds[step_index]
self.prompt_index += 1

if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)

if encoder_hidden_states.dtype != timestep.dtype:
logger.trace("converting UNet hidden states to timestep dtype")
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

return self.wrapped(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
**kwargs,
)

def __getattr__(self, attr):
return getattr(self.wrapped, attr)

def set_prompts(self, prompt_embeds: List[np.ndarray]):
logger.debug(
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
)
self.prompt_embeds = prompt_embeds
self.prompt_index = 0


class VAEWrapper(object):
def __init__(self, server, wrapped):
self.server = server
self.wrapped = wrapped

def __call__(self, latent_sample=None, **kwargs):
global timestep_dtype

logger.trace("VAE parameter types: %s", latent_sample.dtype)
if latent_sample.dtype != timestep_dtype:
logger.info("converting VAE sample dtype")
latent_sample = latent_sample.astype(timestep_dtype)

return self.wrapped(latent_sample=latent_sample, **kwargs)

def __getattr__(self, attr):
return getattr(self.wrapped, attr)


def patch_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
@@ -496,7 +412,7 @@ def patch_pipeline(

if hasattr(pipe, "vae_decoder"):
original_vae = pipe.vae_decoder
pipe.vae_decoder = VAEWrapper(server, original_vae)
pipe.vae_decoder = VAEWrapper(server, original_vae, decoder=True)
elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE
else:
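The timestep_dtype global and both wrapper classes removed above now live in the patches package. The two wrappers still coordinate on a shared dtype: the UNet wrapper records the dtype of the timestep it is called with, and the VAE wrapper casts its input to match before running the ONNX model. A condensed, standalone sketch of that handoff, using plain NumPy arrays rather than a live pipeline:

import numpy as np

from onnx_web.diffusers.patches.vae import set_vae_dtype

# What UNetWrapper.__call__ does on every denoising step: record the dtype
# of the incoming timestep so the VAE can match it later.
timestep = np.array([981], dtype=np.float16)
set_vae_dtype(timestep.dtype)

# What VAEWrapper.__call__ does before decoding: cast mismatched latents.
latents = np.zeros((1, 4, 64, 64), dtype=np.float32)
if latents.dtype != timestep.dtype:
    latents = latents.astype(timestep.dtype)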
72 changes: 72 additions & 0 deletions api/onnx_web/diffusers/patches/unet.py
@@ -0,0 +1,72 @@
from diffusers import OnnxRuntimeModel
from logging import getLogger
from typing import List, Optional

from ...server import ServerContext
from .vae import set_vae_dtype

import numpy as np

logger = getLogger(__name__)


class UNetWrapper(object):
prompt_embeds: Optional[List[np.ndarray]] = None
prompt_index: int = 0
server: ServerContext
wrapped: OnnxRuntimeModel

def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
):
self.server = server
self.wrapped = wrapped

def __call__(
self,
sample: np.ndarray = None,
timestep: np.ndarray = None,
encoder_hidden_states: np.ndarray = None,
**kwargs,
):
logger.trace(
"UNet parameter types: %s, %s, %s",
sample.dtype,
timestep.dtype,
encoder_hidden_states.dtype,
)
set_vae_dtype(timestep.dtype)

if self.prompt_embeds is not None:
step_index = self.prompt_index % len(self.prompt_embeds)
logger.trace("multiple prompt embeds found, using step: %s", step_index)
encoder_hidden_states = self.prompt_embeds[step_index]
self.prompt_index += 1

if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype)

if encoder_hidden_states.dtype != timestep.dtype:
logger.trace("converting UNet hidden states to timestep dtype")
encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

return self.wrapped(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
**kwargs,
)

def __getattr__(self, attr):
return getattr(self.wrapped, attr)

def set_prompts(self, prompt_embeds: List[np.ndarray]):
logger.debug(
"setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
)
self.prompt_embeds = prompt_embeds
self.prompt_index = 0
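
A short usage sketch of the prompt rotation above: after set_prompts, each UNet call consumes the next set of embeddings, wrapping around through the modulo on prompt_index. The ServerContext and ONNX UNet are passed in as parameters because constructing them is outside the scope of this example; the embedding shape follows the usual (batch, 77, 768) CLIP output and is illustrative only.

import numpy as np

from onnx_web.diffusers.patches.unet import UNetWrapper

def demo_prompt_rotation(server, onnx_unet):
    # server: a ServerContext, onnx_unet: the pipeline's OnnxRuntimeModel UNet.
    unet = UNetWrapper(server, onnx_unet)

    # One embedding per step (or per region); three here for the example.
    embeds = [np.full((2, 77, 768), i, dtype=np.float32) for i in range(3)]
    unet.set_prompts(embeds)

    # Each __call__ now picks prompt_embeds[prompt_index % 3] and increments,
    # so successive UNet steps see embeds[0], embeds[1], embeds[2], embeds[0], ...
    return unet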

165 changes: 165 additions & 0 deletions api/onnx_web/diffusers/patches/vae.py
@@ -0,0 +1,165 @@
from typing import Union
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput
from logging import getLogger

import numpy as np
import torch
import torch.nn as nn

from ...server import ServerContext
from diffusers import OnnxRuntimeModel

logger = getLogger(__name__)

LATENT_CHANNELS = 4
SAMPLE_SIZE = 32

# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32


def set_vae_dtype(dtype):
global timestep_dtype
timestep_dtype = dtype


class VAEWrapper(object):
def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
self.server = server
self.wrapped = wrapped
self.decoder = decoder

self.tile_sample_min_size = SAMPLE_SIZE
self.tile_latent_min_size = int(SAMPLE_SIZE / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)

def __call__(self, latent_sample=None, **kwargs):
global timestep_dtype

logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)
if latent_sample.dtype != timestep_dtype:
logger.info("converting VAE sample dtype")
latent_sample = latent_sample.astype(timestep_dtype)

if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
else:
return self.tiled_encode(latent_sample, **kwargs)

def __getattr__(self, attr):
return getattr(self.wrapped, attr)

def blend_v(self, a, b, blend_extent):
for y in range(min(a.shape[2], b.shape[2], blend_extent)):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b

def blend_h(self, a, b, blend_extent):
for x in range(min(a.shape[3], b.shape[3], blend_extent)):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b

def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is:
different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable.
x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.
"""
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)

overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent

# Split the image into overlapping tiles of tile_sample_min_size pixels and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))

moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
posterior = posterior.numpy()

if not return_dict:
return (posterior,)

return AutoencoderKLOutput(latent_dist=posterior)

def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""Decode a batch of images using a tiled decoder.
Args:
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is:
different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
look of the output, but they should be much less noticeable.
z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to
`True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
if isinstance(z, np.ndarray):
z = torch.from_numpy(z)

overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent

# Split z into overlapping tiles of tile_latent_min_size latent pixels and decode them separately.
# The tiles overlap to avoid visible seams between them.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = self(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))

dec = torch.cat(result_rows, dim=2)
dec = dec.numpy()

if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)
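
Plugging the constants at the top of this file into the two tiled paths makes the geometry concrete. With SAMPLE_SIZE = 32 and tile_overlap_factor = 0.25, and assuming a VAE config with four block_out_channels entries (so tile_latent_min_size works out to 4), the decoder steps through the latent grid 3 latent pixels at a time, blends an 8 pixel seam between decoded neighbours, and keeps a 24 pixel crop of each tile; the encoder mirrors this. A standalone recomputation, with the four-entry block_out_channels marked as the only assumption, plus a tiny demonstration of the linear cross-fade that blend_h applies:

import torch

SAMPLE_SIZE = 32
tile_overlap_factor = 0.25
block_out_channels = (128, 256, 512, 512)  # assumed VAE config, not from this diff

tile_sample_min_size = SAMPLE_SIZE                                             # 32 px
tile_latent_min_size = int(SAMPLE_SIZE / 2 ** (len(block_out_channels) - 1))   # 4 latent px

# tiled_decode: step through latents, blend and crop the decoded pixels.
decode_step = int(tile_latent_min_size * (1 - tile_overlap_factor))  # 3 latent px per tile step
decode_blend = int(tile_sample_min_size * tile_overlap_factor)       # 8 px blended seam
decode_keep = tile_sample_min_size - decode_blend                    # 24 px kept per tile

# tiled_encode: step through image pixels, blend and crop the encoded latents.
encode_step = int(tile_sample_min_size * (1 - tile_overlap_factor))  # 24 px per tile step
encode_blend = int(tile_latent_min_size * tile_overlap_factor)       # 1 latent px blended seam
encode_keep = tile_latent_min_size - encode_blend                    # 3 latent px kept per tile

# blend_h cross-fades the first `blend` columns of the right tile with the
# last `blend` columns of the left tile, exactly as in VAEWrapper.blend_h.
left = torch.zeros((1, 1, 1, 4))
right = torch.ones((1, 1, 1, 4))
blend = 2
for x in range(blend):
    right[:, :, :, x] = left[:, :, :, -blend + x] * (1 - x / blend) + right[:, :, :, x] * (x / blend)
# right is now [0.0, 0.5, 1.0, 1.0] along the last axis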
2 changes: 1 addition & 1 deletion api/onnx_web/diffusers/pipelines/panorama.py
@@ -283,7 +283,7 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
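The last visible change halves the default sliding window in the panorama pipeline's get_views from 64 to 32 latent pixels (roughly 512 to 256 image pixels, given the /8 scaling applied just above). The body of get_views is cut off in this diff, so the sketch below assumes the usual MultiDiffusion enumeration of (size - window) // stride + 1 positions per axis; it only illustrates how much denser the view grid becomes with the smaller default.

def count_views(panorama_height, panorama_width, window_size=32, stride=8):
    # Latent dimensions are 1/8 of the requested image size, as in get_views.
    latent_h = panorama_height // 8
    latent_w = panorama_width // 8
    # Assumed sliding-window count; the real get_views may differ in detail.
    blocks_h = (latent_h - window_size) // stride + 1
    blocks_w = (latent_w - window_size) // stride + 1
    return blocks_h * blocks_w

# A 512x2048 panorama has a 64x256 latent grid:
print(count_views(512, 2048, window_size=64))  # 25 views with the old default
print(count_views(512, 2048, window_size=32))  # 145 views with the new default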
