-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
251 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from diffusers import OnnxRuntimeModel | ||
from logging import getLogger | ||
from typing import List, Optional | ||
|
||
from ...server import ServerContext | ||
from .vae import set_vae_dtype | ||
|
||
import numpy as np | ||
|
||
logger = getLogger(__name__) | ||
|
||
|
||
class UNetWrapper(object):
    """
    Callable wrapper around an ONNX UNet model.

    On each call the wrapper casts `sample` and `encoder_hidden_states` to the
    dtype of `timestep`, optionally substitutes a per-step prompt embedding set
    through `set_prompts`, and then forwards everything to the wrapped model.
    Attributes that are not defined on the wrapper are looked up on `wrapped`.
    """

    # Optional list of prompt embeddings; when set, calls cycle through it.
    prompt_embeds: Optional[List[np.ndarray]] = None
    # Number of calls made since the prompt embeddings were last set.
    prompt_index: int = 0
    server: ServerContext
    wrapped: OnnxRuntimeModel

    def __init__(
        self,
        server: ServerContext,
        wrapped: OnnxRuntimeModel,
    ):
        self.server = server
        self.wrapped = wrapped

    def __call__(
        self,
        sample: np.ndarray = None,
        timestep: np.ndarray = None,
        encoder_hidden_states: np.ndarray = None,
        **kwargs,
    ):
        """
        Run the wrapped UNet with inputs cast to the timestep dtype.

        Side effect: the timestep dtype is published to the VAE module through
        `set_vae_dtype`, and `prompt_index` advances by one whenever stored
        prompt embeddings are used.

        Returns whatever the wrapped model returns.
        """
        logger.trace(
            "UNet parameter types: %s, %s, %s",
            sample.dtype,
            timestep.dtype,
            encoder_hidden_states.dtype,
        )
        set_vae_dtype(timestep.dtype)

        # An empty list means "no override": taking the modulo of len([]) would
        # raise ZeroDivisionError, so fall through to the caller's hidden states.
        if self.prompt_embeds is not None and len(self.prompt_embeds) > 0:
            step_index = self.prompt_index % len(self.prompt_embeds)
            logger.trace("multiple prompt embeds found, using step: %s", step_index)
            encoder_hidden_states = self.prompt_embeds[step_index]
            self.prompt_index += 1

        if sample.dtype != timestep.dtype:
            logger.trace("converting UNet sample to timestep dtype")
            sample = sample.astype(timestep.dtype)

        if encoder_hidden_states.dtype != timestep.dtype:
            logger.trace("converting UNet hidden states to timestep dtype")
            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            **kwargs,
        )

    def __getattr__(self, attr):
        """
        Delegate lookups for unknown attributes to the wrapped model.

        Raises AttributeError when `wrapped` itself has not been assigned yet.
        """
        # __getattr__ only runs after normal lookup has failed. If `wrapped` is
        # the missing name (instance built by __new__, copy or unpickle before
        # __init__ ran), reading self.wrapped below would re-enter this method
        # until RecursionError; fail with a normal AttributeError instead.
        if attr == "wrapped":
            raise AttributeError(attr)

        return getattr(self.wrapped, attr)

    def set_prompts(self, prompt_embeds: List[np.ndarray]):
        """
        Store the prompt embeddings to cycle through and restart at the first.
        """
        logger.debug(
            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
        )
        self.prompt_embeds = prompt_embeds
        self.prompt_index = 0
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
from typing import Union | ||
from diffusers.models.autoencoder_kl import AutoencoderKLOutput | ||
from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput | ||
from logging import getLogger | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
from ...server import ServerContext | ||
from diffusers import OnnxRuntimeModel | ||
|
||
logger = getLogger(__name__) | ||
|
||
# Channel count used to size the 1x1 quant_conv (2x) and post_quant_conv layers in VAEWrapper.
LATENT_CHANNELS = 4
# Tile edge length that VAEWrapper stores as tile_sample_min_size and derives tile_latent_min_size from.
SAMPLE_SIZE = 32

# Module-wide dtype that VAEWrapper.__call__ casts its input to; replaced at runtime by set_vae_dtype().
# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32
|
||
|
||
def set_vae_dtype(dtype):
    """
    Replace the module-level `timestep_dtype` with the given dtype.
    """
    # Rebind the name in this module's namespace, same effect as a `global` assignment.
    globals()["timestep_dtype"] = dtype
|
||
|
||
class VAEWrapper(object):
    """
    Wrapper around an ONNX VAE model, acting as decoder or encoder depending on the `decoder` flag.

    Calling the wrapper casts the input to the module-level `timestep_dtype` and dispatches to
    `tiled_decode` or `tiled_encode`. Attributes not found on the wrapper are looked up on `wrapped`.
    """

    def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
        self.server = server
        self.wrapped = wrapped
        self.decoder = decoder

        # Tile geometry. `config` is never assigned on this object, so `self.config` resolves
        # through __getattr__ to `wrapped.config`.
        self.tile_sample_min_size = SAMPLE_SIZE
        self.tile_latent_min_size = int(SAMPLE_SIZE / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25

        # NOTE(review): these 1x1 convolutions are constructed here and no weights are loaded for
        # them anywhere in the visible code - confirm they are meant to run with default init.
        self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
        self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)

    def __call__(self, latent_sample=None, **kwargs):
        """
        Cast `latent_sample` to `timestep_dtype` if needed, then run the tiled decode or encode.

        NOTE(review): `latent_sample` defaults to None but `.dtype` is read unconditionally, and
        `.astype` assumes a numpy array - confirm callers always pass one by this name.
        """
        global timestep_dtype

        logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)
        if latent_sample.dtype != timestep_dtype:
            logger.info("converting VAE sample dtype")
            latent_sample = latent_sample.astype(timestep_dtype)

        # NOTE(review): tiled_decode/tiled_encode call `self(tile)` for every tile, which comes
        # back into this method rather than into `self.wrapped` - looks re-entrant, confirm intent.
        if self.decoder:
            return self.tiled_decode(latent_sample, **kwargs)
        else:
            return self.tiled_encode(latent_sample, **kwargs)

    def __getattr__(self, attr):
        """Delegate lookups for attributes missing on the wrapper to the wrapped model."""
        return getattr(self.wrapped, attr)

    def blend_v(self, a, b, blend_extent):
        """
        Blend the last `blend_extent` rows (dim 2) of `a` into the first rows of `b`.

        The weight moves linearly from `a` towards `b` across the blended rows. `b` is modified
        in place and returned.
        """
        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        """
        Blend the last `blend_extent` columns (dim 3) of `a` into the first columns of `b`.

        The weight moves linearly from `a` towards `b` across the blended columns. `b` is modified
        in place and returned.
        """
        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        r"""Encode a batch of images tile by tile.

        The input is split along dims 2 and 3 into overlapping tiles, each tile is encoded on its
        own, and neighbouring tiles are blended over the overlap before being concatenated. The
        result can differ from a non-tiled encode because every tile is encoded separately.

        Args:
            x (`torch.FloatTensor`): Input batch of images. A numpy array is converted with
                `torch.from_numpy` first.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            `AutoencoderKLOutput(latent_dist=posterior)`, or `(posterior,)` when `return_dict` is
            false.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)

        # overlap_size is the stride between tile origins, blend_extent the number of rows/columns
        # blended between neighbours, row_limit the part of each tile kept in the output.
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into tiles of tile_sample_min_size per side and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                # NOTE(review): `tile` is a torch tensor here and `self(...)` re-enters __call__,
                # which uses `.astype` and dispatches back to this method - confirm.
                tile = self(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)
        # NOTE(review): confirm DiagonalGaussianDistribution actually provides numpy().
        posterior = posterior.numpy()

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""Decode a batch of latents tile by tile.

        The input is split along dims 2 and 3 into overlapping tiles, each tile is decoded on its
        own, and neighbouring tiles are blended over the overlap before being concatenated. The
        result can differ from a non-tiled decode because every tile is decoded separately.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors. A numpy array is converted with
                `torch.from_numpy` first.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.

        Returns:
            `DecoderOutput(sample=dec)`, or `(dec,)` when `return_dict` is false; `dec` is the
            result of calling `.numpy()` on the concatenated tiles.
        """
        if isinstance(z, np.ndarray):
            z = torch.from_numpy(z)

        # overlap_size is the stride between tile origins, blend_extent the number of rows/columns
        # blended between neighbours, row_limit the part of each tile kept in the output.
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles of tile_latent_min_size per side and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[2], overlap_size):
            row = []
            for j in range(0, z.shape[3], overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                # NOTE(review): `tile` is a torch tensor here and `self(...)` re-enters __call__,
                # which uses `.astype` and dispatches back to this method - confirm.
                decoded = self(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        dec = torch.cat(result_rows, dim=2)
        dec = dec.numpy()

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.