diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 90db776f6..1056b195a 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -127,8 +127,12 @@ def load_pipeline( # update panorama params if pipeline == "panorama": - cache_pipe.window = params.tiles // 8 - cache_pipe.stride = params.stride() // 8 + latent_window = params.tiles // 8 + latent_stride = params.stride() // 8 + + cache_pipe.set_window_size(latent_window, latent_stride) + cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride) + cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride) # update scheduler cache_scheduler = server.cache.get("scheduler", scheduler_key) diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py index e290c0875..02c9b925c 100644 --- a/api/onnx_web/diffusers/patches/vae.py +++ b/api/onnx_web/diffusers/patches/vae.py @@ -34,12 +34,15 @@ def __init__( self.server = server self.wrapped = wrapped self.decoder = decoder - self.tiles = tiles + self.set_window_size(tiles, stride) + + def set_window_size(self, window: int, stride: int): + self.window = window self.stride = stride - self.tile_latent_min_size = tiles - self.tile_sample_min_size = tiles * 8 - self.tile_overlap_factor = stride / tiles + self.tile_latent_min_size = self.window + self.tile_sample_min_size = self.window * 8 + self.tile_overlap_factor = self.stride / self.window def __call__(self, latent_sample=None, sample=None, **kwargs): global timestep_dtype @@ -59,7 +62,7 @@ def __call__(self, latent_sample=None, sample=None, **kwargs): logger.debug("converting VAE sample dtype") sample = sample.astype(timestep_dtype) - if self.tiles is not None and self.stride is not None: + if self.window is not None and self.stride is not None: if self.decoder: return self.tiled_decode(latent_sample, **kwargs) else: diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index a9125d1c6..b40fbd092 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -1255,3 +1255,7 @@ def __call__( else: logger.debug("running txt2img panorama pipeline") return self.text2img(*args, **kwargs) + + def set_window_size(self, window: int, stride: int): + self.window = window + self.stride = stride