fix(api): avoid circular deps in diffusion pipeline cache
ssube committed Jan 29, 2023
1 parent 3359138 commit c7a6ec4
Showing 9 changed files with 104 additions and 85 deletions.
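
The cycle being broken is visible in the diff below: the chain stages imported load_pipeline from the single diffusion module, while diffusion.py imported StageParams back from chain. Splitting the module into diffusion/load.py (the pipeline cache, which depends only on params) and diffusion/run.py (the pipeline runners) removes the back-edge. A rough sketch of the import graph before and after, reconstructed from the changed files:

    # Before: chain and diffusion import each other.
    #   onnx_web/chain/blend_img2img.py:  from ..diffusion import load_pipeline    # chain -> diffusion
    #   onnx_web/diffusion.py:            from .chain import StageParams           # diffusion -> chain (cycle)
    #
    # After: the pipeline cache is a leaf module with no chain dependency.
    #   onnx_web/diffusion/load.py:       from ..params import Size
    #   onnx_web/diffusion/run.py:        from ..params import StageParams
    #                                     from .load import get_latents_from_seed, load_pipeline
    #   onnx_web/chain/blend_img2img.py:  from ..diffusion.load import load_pipeline
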
4 changes: 3 additions & 1 deletion api/onnx_web/__init__.py
@@ -5,9 +5,11 @@
upscale_resrgan,
upscale_stable_diffusion,
)
from .diffusion import (
from .diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
from .diffusion.run import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
2 changes: 1 addition & 1 deletion api/onnx_web/chain/blend_img2img.py
@@ -4,7 +4,7 @@
from logging import getLogger
from PIL import Image

from ..diffusion import (
from ..diffusion.load import (
load_pipeline,
)
from ..params import (
2 changes: 1 addition & 1 deletion api/onnx_web/chain/blend_inpaint.py
@@ -5,7 +5,7 @@
from PIL import Image
from typing import Callable, Tuple

from ..diffusion import (
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
7 changes: 5 additions & 2 deletions api/onnx_web/chain/correct_codeformer.py
@@ -1,11 +1,14 @@
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from logging import getLogger
from PIL import Image
from torchvision.transforms.functional import normalize

import torch

logger = getLogger(__name__)

pretrain_model_url = {
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
}
@@ -41,7 +44,7 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
print(f'\tdetect {num_det_faces} faces')
logger.info('detect %s faces', num_det_faces)
# align and warp each face
face_helper.align_warp_face()

@@ -59,7 +62,7 @@ def correct_codeformer(image: Image.Image) -> Image.Image:
del output
torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}')
logger.error('Failed inference for CodeFormer: %s', error)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

restored_face = restored_face.astype('uint8')
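
The CodeFormer stage also swaps its print() calls for a module-level logger with %-style lazy formatting, so messages are only interpolated when the log level is enabled and failures go through the logging pipeline instead of stdout. A minimal, self-contained sketch of the same pattern (the values here are illustrative placeholders):

    from logging import INFO, basicConfig, getLogger

    basicConfig(level=INFO)
    logger = getLogger(__name__)

    num_det_faces = 3  # placeholder; in the stage this comes from face_helper
    logger.info('detect %s faces', num_det_faces)  # lazy interpolation, no f-string

    try:
        raise RuntimeError('stand-in for a failed CodeFormer inference')
    except Exception as error:
        logger.error('Failed inference for CodeFormer: %s', error)
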
2 changes: 1 addition & 1 deletion api/onnx_web/chain/source_txt2img.py
@@ -4,7 +4,7 @@
from logging import getLogger
from PIL import Image

from ..diffusion import (
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_outpaint.py
@@ -5,7 +5,7 @@
from PIL import Image
from typing import Callable, Tuple

from ..diffusion import (
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
)
81 changes: 81 additions & 0 deletions api/onnx_web/diffusion/load.py
@@ -0,0 +1,81 @@
from diffusers import (
DiffusionPipeline,
)
from logging import getLogger
from typing import Any, Optional

from ..params import (
Size,
)

import gc
import numpy as np
import torch

logger = getLogger(__name__)

last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None


def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
'''
From https://www.travelneil.com/stable-diffusion-updates.html
'''
# 1 is batch size
latents_shape = (1, 4, size.height // 8, size.width // 8)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents


def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options

options = (pipeline, model, provider)
if last_pipeline_instance != None and last_pipeline_options == options:
logger.info('reusing existing diffusion pipeline')
pipe = last_pipeline_instance
else:
logger.info('unloading previous diffusion pipeline')
last_pipeline_instance = None
last_pipeline_scheduler = None
gc.collect()
torch.cuda.empty_cache()

logger.info('loading new diffusion pipeline')
pipe = pipeline.from_pretrained(
model,
provider=provider,
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
)

if device is not None:
pipe = pipe.to(device)

last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler

if last_pipeline_scheduler != scheduler:
logger.info('loading new diffusion scheduler')
scheduler = scheduler.from_pretrained(
model, subfolder='scheduler')

if device is not None:
scheduler = scheduler.to(device)

pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler

logger.info('running garbage collection during pipeline change')
gc.collect()

return pipe


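
The new load.py keeps the most recent pipeline in module-level globals and reuses it when the (pipeline, model, provider) tuple matches, otherwise it drops the old instance and garbage-collects before loading a replacement. A condensed, self-contained sketch of that caching pattern (the loader and names are hypothetical stand-ins, torch.cuda.empty_cache() is omitted to stay dependency-free, and it uses is not None rather than the != None comparison seen in the diff):

    import gc
    from typing import Any, Optional, Tuple

    last_instance: Optional[Any] = None
    last_options: Tuple[Any, Any, Any] = (None, None, None)


    def load_cached(loader, model: str, provider: str) -> Any:
        """Return the cached instance when the options match, otherwise rebuild it."""
        global last_instance, last_options

        options = (loader, model, provider)
        if last_instance is not None and last_options == options:
            return last_instance  # cache hit: same pipeline class, model and provider

        # drop the old instance first so its memory can be reclaimed before loading
        last_instance = None
        gc.collect()

        instance = loader(model, provider)  # stand-in for pipeline.from_pretrained(...)
        last_instance = instance
        last_options = options
        return instance


    def fake_loader(model: str, provider: str) -> dict:
        # stand-in for an ONNX pipeline load; returns a cheap placeholder object
        return {'model': model, 'provider': provider}


    pipe_a = load_cached(fake_loader, 'stable-diffusion-onnx', 'CPUExecutionProvider')
    pipe_b = load_cached(fake_loader, 'stable-diffusion-onnx', 'CPUExecutionProvider')
    assert pipe_a is pipe_b  # second call with identical options reuses the instance
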
87 changes: 10 additions & 77 deletions api/onnx_web/diffusion.py → api/onnx_web/diffusion/run.py
@@ -1,106 +1,39 @@
from diffusers import (
DiffusionPipeline,
# onnx
OnnxStableDiffusionPipeline,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger
from PIL import Image, ImageChops
from typing import Any, Optional
from typing import Any

from .chain import (
StageParams,
)
from .image import (
from ..image import (
expand_image,
)
from .params import (
from ..params import (
ImageParams,
Border,
Size,
StageParams,
)
from .upscale import (
from ..upscale import (
run_upscale_correction,
UpscaleParams,
)
from .utils import (
from ..utils import (
is_debug,
base_join,
ServerContext,
)
from .load import (
get_latents_from_seed,
load_pipeline,
)

import gc
import numpy as np
import torch

logger = getLogger(__name__)

last_pipeline_instance = None
last_pipeline_options = (None, None, None)
last_pipeline_scheduler = None


def get_latents_from_seed(seed: int, size: Size) -> np.ndarray:
'''
From https://www.travelneil.com/stable-diffusion-updates.html
'''
# 1 is batch size
latents_shape = (1, 4, size.height // 8, size.width // 8)
# Gotta use numpy instead of torch, because torch's randn() doesn't support DML
rng = np.random.default_rng(seed)
image_latents = rng.standard_normal(latents_shape).astype(np.float32)
return image_latents


def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options

options = (pipeline, model, provider)
if last_pipeline_instance != None and last_pipeline_options == options:
logger.info('reusing existing diffusion pipeline')
pipe = last_pipeline_instance
else:
logger.info('unloading previous diffusion pipeline')
last_pipeline_instance = None
last_pipeline_scheduler = None
gc.collect()
torch.cuda.empty_cache()

logger.info('loading new diffusion pipeline')
pipe = pipeline.from_pretrained(
model,
provider=provider,
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
)

if device is not None:
pipe = pipe.to(device)

last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler

if last_pipeline_scheduler != scheduler:
logger.info('loading new diffusion scheduler')
scheduler = scheduler.from_pretrained(
model, subfolder='scheduler')

if device is not None:
scheduler = scheduler.to(device)

pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler

logger.info('running garbage collection during pipeline change')
gc.collect()

return pipe


def run_txt2img_pipeline(
ctx: ServerContext,
params: ImageParams,
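
get_latents_from_seed, now housed in diffusion/load.py, builds the initial latents with NumPy's seeded generator because torch.randn() does not support DML, and the shape (1, 4, height // 8, width // 8) matches Stable Diffusion's 4-channel, 8x-downsampled latent space. A small standalone check of that behaviour (the Size namedtuple is a stand-in for onnx_web.params.Size):

    from collections import namedtuple

    import numpy as np

    Size = namedtuple('Size', ['width', 'height'])  # stand-in for onnx_web.params.Size


    def get_latents_from_seed(seed: int, size) -> np.ndarray:
        latents_shape = (1, 4, size.height // 8, size.width // 8)
        rng = np.random.default_rng(seed)
        return rng.standard_normal(latents_shape).astype(np.float32)


    latents = get_latents_from_seed(42, Size(width=512, height=512))
    print(latents.shape, latents.dtype)  # (1, 4, 64, 64) float32

    # the same seed always yields the same latents, so generations are reproducible
    assert np.array_equal(latents, get_latents_from_seed(42, Size(512, 512)))
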
2 changes: 1 addition & 1 deletion api/onnx_web/serve.py
@@ -35,7 +35,7 @@
upscale_stable_diffusion,
ChainPipeline,
)
from .diffusion import (
from .diffusion.run import (
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
