From 395a6329466597c55f599cce28c4694284f5d42b Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Mon, 5 Jun 2023 23:16:08 -0500
Subject: [PATCH] fix(api): use VAE model dtype when converting sample

---
 api/onnx_web/diffusers/patches/unet.py |  2 --
 api/onnx_web/diffusers/patches/vae.py  | 26 +++++++++++---------------
 2 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py
index 089bf7207..9ad49cfc7 100644
--- a/api/onnx_web/diffusers/patches/unet.py
+++ b/api/onnx_web/diffusers/patches/unet.py
@@ -5,7 +5,6 @@
 from diffusers import OnnxRuntimeModel
 
 from ...server import ServerContext
-from .vae import set_vae_dtype
 
 logger = getLogger(__name__)
 
@@ -37,7 +36,6 @@ def __call__(
             timestep.dtype,
             encoder_hidden_states.dtype,
         )
-        set_vae_dtype(timestep.dtype)
 
         if self.prompt_embeds is not None:
             step_index = self.prompt_index % len(self.prompt_embeds)
diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py
index 5ae6731c1..754fa4bba 100644
--- a/api/onnx_web/diffusers/patches/vae.py
+++ b/api/onnx_web/diffusers/patches/vae.py
@@ -6,6 +6,7 @@
 from diffusers import OnnxRuntimeModel
 from diffusers.models.autoencoder_kl import AutoencoderKLOutput
 from diffusers.models.vae import DecoderOutput
+from onnx.helper import tensor_dtype_to_np_dtype
 
 from ...server import ServerContext
 
@@ -13,14 +14,6 @@
 
 LATENT_CHANNELS = 4
 
-# TODO: does this need to change for fp16 modes?
-timestep_dtype = np.float32
-
-
-def set_vae_dtype(dtype):
-    global timestep_dtype
-    timestep_dtype = dtype
-
 
 class VAEWrapper(object):
     def __init__(
@@ -46,7 +39,10 @@ def set_window_size(self, window: int, overlap: float):
         self.tile_overlap_factor = overlap
 
     def __call__(self, latent_sample=None, sample=None, **kwargs):
-        global timestep_dtype
+        # set timestep dtype to input type
+        inputs = self.wrapped.model.graph.input
+        sample_input = [i for i in inputs if i.name == "sample" or i.name == "latent_sample"][0]
+        sample_dtype = tensor_dtype_to_np_dtype(sample_input.type.tensor_type.elem_type)
 
         logger.trace(
             "VAE %s parameter types: %s, %s",
@@ -55,13 +51,13 @@ def __call__(self, latent_sample=None, sample=None, **kwargs):
             (latent_sample.dtype if latent_sample is not None else "none"),
             (sample.dtype if sample is not None else "none"),
         )
-        if latent_sample is not None and latent_sample.dtype != timestep_dtype:
-            logger.debug("converting VAE latent sample dtype")
-            latent_sample = latent_sample.astype(timestep_dtype)
+        if latent_sample is not None and latent_sample.dtype != sample_dtype:
+            logger.debug("converting VAE latent sample dtype to %s", sample_dtype)
+            latent_sample = latent_sample.astype(sample_dtype)
 
-        if sample is not None and sample.dtype != timestep_dtype:
-            logger.debug("converting VAE sample dtype")
-            sample = sample.astype(timestep_dtype)
+        if sample is not None and sample.dtype != sample_dtype:
+            logger.debug("converting VAE sample dtype to %s", sample_dtype)
+            sample = sample.astype(sample_dtype)
 
         if self.tiled:
             if self.decoder:
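
Note (not part of the patch): a minimal standalone sketch of the dtype lookup that the patched VAEWrapper.__call__ now performs, reading the expected numpy dtype from the ONNX graph's "sample"/"latent_sample" input instead of relying on a global timestep_dtype. The model path is hypothetical, and tensor_dtype_to_np_dtype assumes onnx >= 1.13.

    # Standalone sketch of the dtype lookup used in the patched VAEWrapper above.
    import numpy as np
    import onnx
    from onnx.helper import tensor_dtype_to_np_dtype

    model = onnx.load("models/vae_decoder/model.onnx")  # hypothetical path

    # The VAE decoder exposes a "latent_sample" input; the encoder exposes "sample".
    sample_input = next(
        i for i in model.graph.input if i.name in ("sample", "latent_sample")
    )

    # elem_type is an ONNX TensorProto.DataType enum; map it to the matching numpy dtype.
    sample_dtype = tensor_dtype_to_np_dtype(sample_input.type.tensor_type.elem_type)

    # Cast an incoming latent batch to whatever the exported model expects (fp16 or fp32).
    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
    if latents.dtype != sample_dtype:
        latents = latents.astype(sample_dtype)

    print(sample_input.name, sample_dtype, latents.dtype)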