Skip to content

Commit

Permalink
feat(api): add model cache for diffusion models
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 14, 2023
1 parent 7fa1783 commit e9472bc
Show file tree
Hide file tree
Showing 24 changed files with 111 additions and 66 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
noise_source_uniform,
)
from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams
from .upscale import run_upscale_correction
from .server.upscale import run_upscale_correction
from .utils import (
ServerContext,
base_join,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from PIL import Image

from ..device_pool import JobContext, ProgressCallback
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug
from .utils import process_tile_order

Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/blend_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image

from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import load_pipeline
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext

logger = getLogger(__name__)


def blend_img2img(
job: JobContext,
_server: ServerContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
Expand All @@ -30,6 +30,7 @@ def blend_img2img(
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)

pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image

from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug
from .utils import process_tile_order

Expand Down Expand Up @@ -65,6 +65,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):

latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from onnx_web.output import save_image

from ..device_pool import JobContext, ProgressCallback
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from PIL import Image

from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from gfpgan import GFPGANer
from PIL import Image

from ..device_pool import JobContext
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext, run_gc

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/persist_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from PIL import Image

from ..device_pool import JobContext
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/persist_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from boto3 import Session
from PIL import Image

from ..device_pool import JobContext
from ..params import ImageParams, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/reduce_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from PIL import Image

from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/reduce_thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from PIL import Image

from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/source_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from PIL import Image

from ..device_pool import JobContext
from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext
from ..utils import ServerContext

logger = getLogger(__name__)
Expand Down
5 changes: 3 additions & 2 deletions api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from diffusers import OnnxStableDiffusionPipeline
from PIL import Image

from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, load_pipeline
from ..params import ImageParams, Size, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext

logger = getLogger(__name__)
Expand All @@ -16,7 +16,7 @@
def source_txt2img(
job: JobContext,
server: ServerContext,
stage: StageParams,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
Expand All @@ -35,6 +35,7 @@ def source_txt2img(

latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
params.model,
params.scheduler,
Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from diffusers import OnnxStableDiffusionInpaintPipeline
from PIL import Image, ImageDraw

from ..device_pool import JobContext, ProgressCallback
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid, process_tile_order

Expand Down Expand Up @@ -73,6 +73,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):

latents = get_tile_latents(full_latents, dims)
pipe = load_pipeline(
server,
OnnxStableDiffusionInpaintPipeline,
params.model,
params.scheduler,
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

from ..device_pool import JobContext
from ..onnx import OnnxNet
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext
from ..utils import ServerContext, run_gc

logger = getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

from ..device_pool import JobContext, ProgressCallback
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server.device_pool import JobContext, ProgressCallback
from ..utils import ServerContext, run_gc

logger = getLogger(__name__)
Expand Down
67 changes: 29 additions & 38 deletions api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,10 @@
)

from ..params import DeviceParams, Size
from ..utils import run_gc
from ..utils import ServerContext, run_gc

logger = getLogger(__name__)

last_pipeline_instance: Any = None
last_pipeline_options: Tuple[
Optional[DiffusionPipeline],
Optional[str],
Optional[str],
Optional[str],
Optional[bool],
] = (None, None, None, None, None)
last_pipeline_scheduler: Any = None

latent_channels = 4
latent_factor = 8

Expand Down Expand Up @@ -90,24 +80,42 @@ def get_tile_latents(


def load_pipeline(
server: ServerContext,
pipeline: DiffusionPipeline,
model: str,
scheduler_type: Any,
device: DeviceParams,
lpw: bool,
):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
pipe_key = (pipeline, model, device.device, device.provider, lpw)
scheduler_key = (scheduler_type,)

cache_pipe = server.cache.get("diffusion", pipe_key)

options = (pipeline, model, device.device, device.provider, lpw)
if last_pipeline_instance is not None and last_pipeline_options == options:
if cache_pipe is not None:
logger.debug("reusing existing diffusion pipeline")
pipe = last_pipeline_instance
pipe = cache_pipe

cache_scheduler = server.cache.get("scheduler", scheduler_key)
if cache_scheduler is None:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder="scheduler",
)

if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device.torch_device())

pipe.scheduler = scheduler
server.cache.set("scheduler", scheduler_key, scheduler)
run_gc()

else:
logger.debug("unloading previous diffusion pipeline")
last_pipeline_instance = None
last_pipeline_scheduler = None
server.cache.drop("diffusion", pipe_key)
run_gc()

if lpw:
Expand Down Expand Up @@ -135,24 +143,7 @@ def load_pipeline(
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_device())

last_pipeline_instance = pipe
last_pipeline_options = options
last_pipeline_scheduler = scheduler_type

if last_pipeline_scheduler != scheduler_type:
logger.debug("loading new diffusion scheduler")
scheduler = scheduler_type.from_pretrained(
model,
provider=device.provider,
provider_options=device.options,
subfolder="scheduler",
)

if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device.torch_device())

pipe.scheduler = scheduler
last_pipeline_scheduler = scheduler_type
run_gc()
server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, scheduler)

return pipe
6 changes: 4 additions & 2 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from onnx_web.chain.base import ChainProgress

from ..chain import upscale_outpaint
from ..device_pool import JobContext
from ..output import save_image, save_params
from ..params import Border, ImageParams, Size, StageParams
from ..upscale import UpscaleParams, run_upscale_correction
from ..server.device_pool import JobContext
from ..server.upscale import UpscaleParams, run_upscale_correction
from ..utils import ServerContext, run_gc
from .load import get_latents_from_seed, load_pipeline

Expand All @@ -30,6 +30,7 @@ def run_txt2img_pipeline(
) -> None:
latents = get_latents_from_seed(params.seed, size)
pipe = load_pipeline(
server,
OnnxStableDiffusionPipeline,
params.model,
params.scheduler,
Expand Down Expand Up @@ -97,6 +98,7 @@ def run_img2img_pipeline(
strength: float,
) -> None:
pipe = load_pipeline(
server,
OnnxStableDiffusionImg2ImgPipeline,
params.model,
params.scheduler,
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
upscale_resrgan,
upscale_stable_diffusion,
)
from .device_pool import DevicePoolExecutor
from .diffusion.load import pipeline_schedulers
from .diffusion.run import (
run_blend_pipeline,
Expand All @@ -40,7 +39,6 @@
run_txt2img_pipeline,
run_upscale_pipeline,
)
from .hacks import apply_patches
from .image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
Expand All @@ -62,6 +60,8 @@
TileOrder,
UpscaleParams,
)
from .server.device_pool import DevicePoolExecutor
from .server.hacks import apply_patches
from .utils import (
ServerContext,
base_join,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from traceback import format_exception
from typing import Any, Callable, List, Optional, Tuple, Union

from .params import DeviceParams
from .utils import run_gc
from ..params import DeviceParams
from ..utils import run_gc

logger = getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/hacks.py → api/onnx_web/server/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import basicsr.utils.download_util
import codeformer.facelib.utils.misc

from .utils import ServerContext
from ..utils import ServerContext

logger = getLogger(__name__)

Expand Down
Loading

0 comments on commit e9472bc

Please sign in to comment.