Skip to content

Commit

Permalink
remove conv breaking VAE wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 28, 2023
1 parent 7e5380d commit 79115e5
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions api/onnx_web/diffusers/patches/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import torch
import torch.nn as nn

from ...server import ServerContext
from diffusers import OnnxRuntimeModel
Expand Down Expand Up @@ -35,9 +34,6 @@ def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bo
self.tile_latent_min_size = LATENT_SIZE
self.tile_overlap_factor = 0.25

self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)

def __call__(self, latent_sample=None, **kwargs):
global timestep_dtype

Expand Down Expand Up @@ -90,7 +86,6 @@ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> Autoen
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
Expand Down Expand Up @@ -142,7 +137,6 @@ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
row.append(decoded)
rows.append(row)
Expand Down

0 comments on commit 79115e5

Please sign in to comment.