diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 90f3bfe5f..73b0ebc1a 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -11,6 +11,7 @@ UpscaleParams, ) from ..utils import ( + run_gc, ServerContext, ) from .upscale_resrgan import ( @@ -51,6 +52,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[ last_pipeline_instance = gfpgan last_pipeline_params = face_path + run_gc() return gfpgan diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index bc9ec56d4..ee58733d7 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -13,6 +13,7 @@ UpscaleParams, ) from ..utils import ( + run_gc, ServerContext, ) @@ -66,6 +67,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): last_pipeline_instance = upsampler last_pipeline_params = cache_params + run_gc() return upsampler diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 36ed72943..1f3f722f6 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -14,6 +14,7 @@ UpscaleParams, ) from ..utils import ( + run_gc, ServerContext, ) @@ -44,6 +45,7 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams): last_pipeline_instance = pipeline last_pipeline_params = cache_params + run_gc() return pipeline diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 2f8a71298..c8e457e4a 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -7,10 +7,12 @@ from ..params import ( Size, ) +from ..utils import ( + run_gc, +) import gc import numpy as np -import torch logger = getLogger(__name__) @@ -39,7 +41,7 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np xt = x + t yt = y + t - return full_latents[:,:,y:yt,x:xt] + return full_latents[:, :, y:yt, x:xt] def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None): @@ -55,8 +57,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu logger.info('unloading previous diffusion pipeline') last_pipeline_instance = None last_pipeline_scheduler = None - gc.collect() - torch.cuda.empty_cache() + run_gc() logger.info('loading new diffusion pipeline') pipe = pipeline.from_pretrained( @@ -83,10 +84,6 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu pipe.scheduler = scheduler last_pipeline_scheduler = scheduler - - logger.info('running garbage collection during pipeline change') - gc.collect() + run_gc() return pipe - - diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 7145e97f5..80c79810e 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -1,7 +1,6 @@ from diffusers import ( OnnxStableDiffusionPipeline, OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionInpaintPipeline, ) from logging import getLogger from PIL import Image, ImageChops @@ -10,9 +9,6 @@ from ..chain import ( upscale_outpaint, ) -from ..image import ( - expand_image, -) from ..params import ( ImageParams, Border, @@ -24,19 +20,20 @@ UpscaleParams, ) from ..utils import ( - is_debug, base_join, + run_gc, ServerContext, ) from .load import ( - get_latents_from_seed, - load_pipeline, + get_latents_from_seed, + load_pipeline, ) import numpy as np logger = getLogger(__name__) + def run_txt2img_pipeline( ctx: ServerContext, params: ImageParams, @@ -69,6 +66,7 @@ def run_txt2img_pipeline( del image del result + run_gc() logger.info('saved txt2img output: %s', dest) @@ -104,6 +102,7 @@ def run_img2img_pipeline( del image del result + run_gc() logger.info('saved img2img output: %s', dest) @@ -139,7 +138,8 @@ def run_inpaint_pipeline( if image.size == source_image.size: image = ImageChops.blend(source_image, image, strength) else: - logger.info('output image size does not match source, skipping post-blend') + logger.info( + 'output image size does not match source, skipping post-blend') image = run_upscale_correction( ctx, stage, params, image, upscale=upscale) @@ -148,6 +148,7 @@ def run_inpaint_pipeline( image.save(dest) del image + run_gc() logger.info('saved inpaint output: %s', dest) @@ -167,5 +168,6 @@ def run_upscale_pipeline( image.save(dest) del image + run_gc() logger.info('saved img2img output: %s', dest) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 3e15c7ccf..2090b0854 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -5,6 +5,9 @@ from time import time from typing import Any, Dict, List, Optional, Union, Tuple +import gc +import torch + from .params import ( ImageParams, Param, @@ -158,3 +161,9 @@ def make_output_name( def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join('/', tail)), '/') return path.join(base, tail_path) + + +def run_gc(): + logger.debug('running garbage collection') + gc.collect() + torch.cuda.empty_cache() \ No newline at end of file