From 9c5043e9d0e516b06dc3de989826ebc7cedd5e0b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Feb 2023 09:33:13 -0600 Subject: [PATCH] fix(api): correctly cache diffusers scheduler --- api/onnx_web/diffusion/load.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 761db0a01..9256f1aa2 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -55,7 +55,7 @@ def get_tile_latents( def load_pipeline( pipeline: DiffusionPipeline, model: str, - scheduler: Any, + scheduler_type: Any, device: DeviceParams, lpw: bool, ): @@ -79,7 +79,7 @@ def load_pipeline( custom_pipeline = None logger.debug("loading new diffusion pipeline from %s", model) - scheduler = scheduler.from_pretrained( + scheduler = scheduler_type.from_pretrained( model, provider=device.provider, provider_options=device.options, @@ -100,11 +100,11 @@ def load_pipeline( last_pipeline_instance = pipe last_pipeline_options = options - last_pipeline_scheduler = scheduler + last_pipeline_scheduler = scheduler_type if last_pipeline_scheduler != scheduler: logger.debug("loading new diffusion scheduler") - scheduler = scheduler.from_pretrained( + scheduler = scheduler_type.from_pretrained( model, provider=device.provider, provider_options=device.options, @@ -112,10 +112,10 @@ def load_pipeline( ) if device is not None and hasattr(scheduler, "to"): - scheduler = scheduler.to(device) + scheduler = scheduler.to(device.torch_device()) pipe.scheduler = scheduler - last_pipeline_scheduler = scheduler + last_pipeline_scheduler = scheduler_type run_gc() return pipe