Commit: apply lint

ssube committed Apr 28, 2023
1 parent 79115e5, commit 1d9de2c
Showing 5 changed files with 185 additions and 61 deletions.
1 change: 0 additions & 1 deletion api/onnx_web/diffusers/load.py
@@ -2,7 +2,6 @@
 from os import path
 from typing import Any, List, Optional, Tuple
 
-import numpy as np
 from onnx import load_model
 from transformers import CLIPTokenizer
 
7 changes: 3 additions & 4 deletions api/onnx_web/diffusers/patches/unet.py
@@ -1,12 +1,12 @@
-from diffusers import OnnxRuntimeModel
 from logging import getLogger
 from typing import List, Optional
 
+import numpy as np
+from diffusers import OnnxRuntimeModel
+
 from ...server import ServerContext
 from .vae import set_vae_dtype
 
-import numpy as np
-
 logger = getLogger(__name__)


@@ -69,4 +69,3 @@ def set_prompts(self, prompt_embeds: List[np.ndarray]):
         )
         self.prompt_embeds = prompt_embeds
         self.prompt_index = 0
-
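An aside on the set_prompts fragment above: the patch stores a list of per-step prompt embeddings and resets an index into them. How the index advances is not visible in this diff; a minimal sketch of the likely cycling pattern (the class shape and next_prompt are assumptions for illustration):

from typing import List

import numpy as np


class PromptCycler:
    # stand-in for the patched UNet wrapper's prompt bookkeeping
    def __init__(self):
        self.prompt_embeds: List[np.ndarray] = []
        self.prompt_index = 0

    def set_prompts(self, prompt_embeds: List[np.ndarray]):
        # store one embedding array per prompt and restart from the first
        self.prompt_embeds = prompt_embeds
        self.prompt_index = 0

    def next_prompt(self) -> np.ndarray:
        # hypothetical: step through the stored embeddings, wrapping around
        embeds = self.prompt_embeds[self.prompt_index % len(self.prompt_embeds)]
        self.prompt_index += 1
        return embeds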
46 changes: 34 additions & 12 deletions api/onnx_web/diffusers/patches/vae.py
@@ -1,13 +1,13 @@
-from typing import Union
-from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput
 from logging import getLogger
+from typing import Union
 
 import numpy as np
 import torch
+from diffusers import OnnxRuntimeModel
+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
 
 from ...server import ServerContext
-from diffusers import OnnxRuntimeModel
 
 logger = getLogger(__name__)
 
@@ -37,7 +37,11 @@ def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bo
     def __call__(self, latent_sample=None, **kwargs):
         global timestep_dtype
 
-        logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)
+        logger.trace(
+            "VAE %s parameter types: %s",
+            ("decoder" if self.decoder else "encoder"),
+            latent_sample.dtype,
+        )
         if latent_sample.dtype != timestep_dtype:
             logger.info("converting VAE sample dtype")
             latent_sample = latent_sample.astype(timestep_dtype)
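The rewrapped trace call above guards a dtype cast: any sample that does not match the module-level timestep dtype is converted before the ONNX VAE runs. The guard in isolation (the float32 default is an assumption):

import numpy as np

timestep_dtype = np.float32  # assumed default; the patch tracks this in a module global


def ensure_timestep_dtype(latent_sample: np.ndarray) -> np.ndarray:
    # cast the incoming sample to the dtype the ONNX model expects
    if latent_sample.dtype != timestep_dtype:
        latent_sample = latent_sample.astype(timestep_dtype)
    return latent_sample


print(ensure_timestep_dtype(np.ones((1, 4, 8, 8), dtype=np.float16)).dtype)  # float32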
@@ -52,16 +56,22 @@ def __getattr__(self, attr):

     def blend_v(self, a, b, blend_extent):
         for y in range(min(a.shape[2], b.shape[2], blend_extent)):
-            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[
+                :, :, y, :
+            ] * (y / blend_extent)
         return b
 
     def blend_h(self, a, b, blend_extent):
         for x in range(min(a.shape[3], b.shape[3], blend_extent)):
-            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[
+                :, :, :, x
+            ] * (x / blend_extent)
         return b
 
     @torch.no_grad()
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+    def tiled_encode(
+        self, x: torch.FloatTensor, return_dict: bool = True
+    ) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
         Args:
             When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
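The blend_v and blend_h helpers rewrapped above are linear cross-fades across a tile seam: at offset x into the overlap, the previous tile is weighted 1 - x / blend_extent and the current tile x / blend_extent. A standalone check of the horizontal case on toy data:

import numpy as np


def blend_h(a: np.ndarray, b: np.ndarray, blend_extent: int) -> np.ndarray:
    # fade the last blend_extent columns of a into the first columns of b
    for x in range(min(a.shape[3], b.shape[3], blend_extent)):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (
            1 - x / blend_extent
        ) + b[:, :, :, x] * (x / blend_extent)
    return b


a = np.zeros((1, 1, 1, 8))  # previous tile
b = np.ones((1, 1, 1, 8))   # current tile
print(blend_h(a, b, 4)[0, 0, 0])  # [0.  0.25 0.5  0.75 1.  1.  1.  1. ]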
@@ -84,7 +94,12 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen
         for i in range(0, x.shape[2], overlap_size):
             row = []
             for j in range(0, x.shape[3], overlap_size):
-                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = x[
+                    :,
+                    :,
+                    i : i + self.tile_sample_min_size,
+                    j : j + self.tile_sample_min_size,
+                ]
                 tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
                 row.append(tile)
             rows.append(row)
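For reference, the loop above walks the input in overlap_size strides and cuts tile_sample_min_size squares, so neighboring tiles share an overlap band and border tiles simply come back smaller. The same walk in NumPy, with toy sizes:

import numpy as np

tile_size = 4     # stands in for tile_sample_min_size
overlap_size = 3  # stride between tile origins; smaller than tile_size, so tiles overlap

x = np.arange(64, dtype=np.float32).reshape(1, 1, 8, 8)
rows = []
for i in range(0, x.shape[2], overlap_size):
    row = []
    for j in range(0, x.shape[3], overlap_size):
        tile = x[:, :, i : i + tile_size, j : j + tile_size]
        row.append(tile)  # the patch runs the wrapped ONNX model on each tile here
    rows.append(row)
print(len(rows), len(rows[0]))  # 3 3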
@@ -111,7 +126,9 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen
         return AutoencoderKLOutput(latent_dist=posterior)
 
     @torch.no_grad()
-    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def tiled_decode(
+        self, z: torch.FloatTensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         r"""Decode a batch of images using a tiled decoder.
         Args:
             When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
@@ -136,7 +153,12 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[
         for i in range(0, z.shape[2], overlap_size):
             row = []
             for j in range(0, z.shape[3], overlap_size):
-                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = z[
+                    :,
+                    :,
+                    i : i + self.tile_latent_min_size,
+                    j : j + self.tile_latent_min_size,
+                ]
                 decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
                 row.append(decoded)
             rows.append(row)
@@ -160,4 +182,4 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[
         if not return_dict:
             return (dec,)
 
-        return DecoderOutput(sample=dec)
\ No newline at end of file
+        return DecoderOutput(sample=dec)
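The middle of tiled_decode is collapsed in this view; assuming it matches the upstream diffusers implementation this file adapts, each decoded tile is blended against its upper and left neighbors, cropped to a fixed limit, and concatenated. A sketch of that assembly step under that assumption:

import numpy as np


def assemble_tiles(rows, blend_v, blend_h, blend_extent, limit):
    # blend each tile with the tiles above and to the left, crop, then stitch
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :limit, :limit])
        result_rows.append(np.concatenate(result_row, axis=3))
    return np.concatenate(result_rows, axis=2)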