diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py
index 996d4cc6b..eb1395c67 100644
--- a/api/onnx_web/__init__.py
+++ b/api/onnx_web/__init__.py
@@ -5,9 +5,11 @@
     upscale_resrgan,
     upscale_stable_diffusion,
 )
-from .diffusion import (
+from .diffusion.load import (
     get_latents_from_seed,
     load_pipeline,
+)
+from .diffusion.run import (
     run_img2img_pipeline,
     run_inpaint_pipeline,
     run_txt2img_pipeline,
diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index 10b35a427..3bd24933c 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -4,7 +4,7 @@
 from logging import getLogger
 from PIL import Image
 
-from ..diffusion import (
+from ..diffusion.load import (
     load_pipeline,
 )
 from ..params import (
diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py
index 2891b55c6..bdf427802 100644
--- a/api/onnx_web/chain/blend_inpaint.py
+++ b/api/onnx_web/chain/blend_inpaint.py
@@ -5,7 +5,7 @@
 from PIL import Image
 from typing import Callable, Tuple
 
-from ..diffusion import (
+from ..diffusion.load import (
     get_latents_from_seed,
     load_pipeline,
 )
diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py
index 8044092ac..5eae5b122 100644
--- a/api/onnx_web/chain/correct_codeformer.py
+++ b/api/onnx_web/chain/correct_codeformer.py
@@ -1,11 +1,14 @@
 from basicsr.utils import img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+from logging import getLogger
 from PIL import Image
 from torchvision.transforms.functional import normalize
 
 import torch
 
+logger = getLogger(__name__)
+
 pretrain_model_url = {
     'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
 }
@@ -41,7 +44,7 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
     # get face landmarks for each face
     num_det_faces = face_helper.get_face_landmarks_5(
         only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
-    print(f'\tdetect {num_det_faces} faces')
+    logger.info('detect %s faces', num_det_faces)
 
     # align and warp each face
     face_helper.align_warp_face()
@@ -59,7 +62,7 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
             del output
             torch.cuda.empty_cache()
         except Exception as error:
-            print(f'\tFailed inference for CodeFormer: {error}')
+            logger.error('Failed inference for CodeFormer: %s', error)
             restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
         restored_face = restored_face.astype('uint8')
diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py
index 08765bec7..88d495901 100644
--- a/api/onnx_web/chain/source_txt2img.py
+++ b/api/onnx_web/chain/source_txt2img.py
@@ -4,7 +4,7 @@
 from logging import getLogger
 from PIL import Image
 
-from ..diffusion import (
+from ..diffusion.load import (
     get_latents_from_seed,
     load_pipeline,
 )
diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py
index 250234740..bbb4ae30c 100644
--- a/api/onnx_web/chain/upscale_outpaint.py
+++ b/api/onnx_web/chain/upscale_outpaint.py
@@ -5,7 +5,7 @@
 from PIL import Image
 from typing import Callable, Tuple
 
-from ..diffusion import (
+from ..diffusion.load import (
     get_latents_from_seed,
     load_pipeline,
 )
diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py
new file mode 100644
index 000000000..2e375551f
--- /dev/null
+++ b/api/onnx_web/diffusion/load.py
@@ -0,0 +1,81 @@
+from diffusers import (
+    DiffusionPipeline,
+)
+from logging import getLogger
+from typing import Any, Optional
+
+from ..params import (
+    Size,
+)
+
+import gc
+import numpy as np
+import torch
+
+logger = getLogger(__name__)
+
+last_pipeline_instance = None
+last_pipeline_options = (None, None, None)
+last_pipeline_scheduler = None
+
+
+def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
+    '''
+    From https://www.travelneil.com/stable-diffusion-updates.html
+    '''
+    # 1 is batch size
+    latents_shape = (1, 4, size.height // 8, size.width // 8)
+    # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
+    rng = np.random.default_rng(seed)
+    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
+    return image_latents
+
+
+def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
+    global last_pipeline_instance
+    global last_pipeline_scheduler
+    global last_pipeline_options
+
+    options = (pipeline, model, provider)
+    if last_pipeline_instance != None and last_pipeline_options == options:
+        logger.info('reusing existing diffusion pipeline')
+        pipe = last_pipeline_instance
+    else:
+        logger.info('unloading previous diffusion pipeline')
+        last_pipeline_instance = None
+        last_pipeline_scheduler = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        logger.info('loading new diffusion pipeline')
+        pipe = pipeline.from_pretrained(
+            model,
+            provider=provider,
+            safety_checker=None,
+            scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
+        )
+
+        if device is not None:
+            pipe = pipe.to(device)
+
+        last_pipeline_instance = pipe
+        last_pipeline_options = options
+        last_pipeline_scheduler = scheduler
+
+    if last_pipeline_scheduler != scheduler:
+        logger.info('loading new diffusion scheduler')
+        scheduler = scheduler.from_pretrained(
+            model, subfolder='scheduler')
+
+        if device is not None:
+            scheduler = scheduler.to(device)
+
+        pipe.scheduler = scheduler
+        last_pipeline_scheduler = scheduler
+
+    logger.info('running garbage collection during pipeline change')
+    gc.collect()
+
+    return pipe
+
+
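The new api/onnx_web/diffusion/load.py above owns pipeline caching and seed-derived latents. The snippet below is a minimal sketch of how a caller might drive it directly, not code from this change: the model path, the 'CPUExecutionProvider' string, the prompt, the step counts, and the Size(width, height) argument order are all assumed placeholders, and DDPMScheduler is just one diffusers scheduler that supports from_pretrained(..., subfolder='scheduler').

    # Hypothetical caller for the new load module; paths and provider are placeholders.
    from diffusers import DDPMScheduler, OnnxStableDiffusionPipeline

    from onnx_web.diffusion.load import get_latents_from_seed, load_pipeline
    from onnx_web.params import Size

    model = '../models/stable-diffusion-onnx-v1-5'  # assumed local ONNX export

    # load_pipeline caches on the (pipeline, model, provider) tuple; repeated calls
    # with the same tuple reuse the existing pipeline instead of reloading it.
    pipe = load_pipeline(
        OnnxStableDiffusionPipeline,
        model,
        'CPUExecutionProvider',
        DDPMScheduler,
    )

    # Fixed seed -> reproducible latents, shaped (1, 4, height // 8, width // 8).
    latents = get_latents_from_seed(42, Size(512, 512))

    result = pipe(
        'an astronaut riding a horse',
        height=512,
        width=512,
        num_inference_steps=25,
        guidance_scale=7.5,
        latents=latents,
    )
    result.images[0].save('txt2img-example.png')

Caching on (pipeline, model, provider) is what lets repeated requests with the same settings skip the expensive from_pretrained call; only a changed scheduler triggers the lighter scheduler reload path.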
diff --git a/api/onnx_web/diffusion.py b/api/onnx_web/diffusion/run.py
similarity index 66%
rename from api/onnx_web/diffusion.py
rename to api/onnx_web/diffusion/run.py
index d1598d1a1..ee4089ac0 100644
--- a/api/onnx_web/diffusion.py
+++ b/api/onnx_web/diffusion/run.py
@@ -1,106 +1,39 @@
 from diffusers import (
-    DiffusionPipeline,
-    # onnx
     OnnxStableDiffusionPipeline,
     OnnxStableDiffusionImg2ImgPipeline,
     OnnxStableDiffusionInpaintPipeline,
 )
 from logging import getLogger
 from PIL import Image, ImageChops
-from typing import Any, Optional
+from typing import Any
 
-from .chain import (
-    StageParams,
-)
-from .image import (
+from ..image import (
     expand_image,
 )
-from .params import (
+from ..params import (
     ImageParams,
     Border,
     Size,
+    StageParams,
 )
-from .upscale import (
+from ..upscale import (
     run_upscale_correction,
     UpscaleParams,
 )
-from .utils import (
+from ..utils import (
     is_debug,
     base_join,
     ServerContext,
 )
+from .load import (
+    get_latents_from_seed,
+    load_pipeline,
+)
 
-import gc
 import numpy as np
-import torch
 
 logger = getLogger(__name__)
 
-last_pipeline_instance = None
-last_pipeline_options = (None, None, None)
-last_pipeline_scheduler = None
-
-
-def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
-    '''
-    From https://www.travelneil.com/stable-diffusion-updates.html
-    '''
-    # 1 is batch size
-    latents_shape = (1, 4, size.height // 8, size.width // 8)
-    # Gotta use numpy instead of torch, because torch's randn() doesn't support DML
-    rng = np.random.default_rng(seed)
-    image_latents = rng.standard_normal(latents_shape).astype(np.float32)
-    return image_latents
-
-
-def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
-    global last_pipeline_instance
-    global last_pipeline_scheduler
-    global last_pipeline_options
-
-    options = (pipeline, model, provider)
-    if last_pipeline_instance != None and last_pipeline_options == options:
-        logger.info('reusing existing diffusion pipeline')
-        pipe = last_pipeline_instance
-    else:
-        logger.info('unloading previous diffusion pipeline')
-        last_pipeline_instance = None
-        last_pipeline_scheduler = None
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        logger.info('loading new diffusion pipeline')
-        pipe = pipeline.from_pretrained(
-            model,
-            provider=provider,
-            safety_checker=None,
-            scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
-        )
-
-        if device is not None:
-            pipe = pipe.to(device)
-
-        last_pipeline_instance = pipe
-        last_pipeline_options = options
-        last_pipeline_scheduler = scheduler
-
-    if last_pipeline_scheduler != scheduler:
-        logger.info('loading new diffusion scheduler')
-        scheduler = scheduler.from_pretrained(
-            model, subfolder='scheduler')
-
-        if device is not None:
-            scheduler = scheduler.to(device)
-
-        pipe.scheduler = scheduler
-        last_pipeline_scheduler = scheduler
-
-    logger.info('running garbage collection during pipeline change')
-    gc.collect()
-
-    return pipe
-
-
 def run_txt2img_pipeline(
     ctx: ServerContext,
     params: ImageParams,
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index f8eb1327b..34833ce48 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -35,7 +35,7 @@
     upscale_stable_diffusion,
     ChainPipeline,
 )
-from .diffusion import (
+from .diffusion.run import (
     run_img2img_pipeline,
     run_inpaint_pipeline,
     run_txt2img_pipeline,
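For downstream code, the net effect of this change is an import split: chain stages and anything building latents now import from the load module, serve.py and the job runners import from the run module, and run.py picks up StageParams from params instead of chain. The package root (__init__.py) keeps re-exporting both halves, so existing package-level imports continue to resolve. A hedged summary of the new import surface, with package-style paths assumed to mirror the relative imports above:

    # Import layout after this refactor (paths assumed from the relative imports in the diff).
    from onnx_web.diffusion.load import get_latents_from_seed, load_pipeline
    from onnx_web.diffusion.run import (
        run_img2img_pipeline,
        run_inpaint_pipeline,
        run_txt2img_pipeline,
    )
    # run.py now takes StageParams from onnx_web.params; it was imported from onnx_web.chain before.
    from onnx_web.params import StageParams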