diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py
index 208d2ffa9..beca62dd0 100644
--- a/api/onnx_web/__init__.py
+++ b/api/onnx_web/__init__.py
@@ -17,7 +17,7 @@
     run_upscale_pipeline,
 )
 from .diffusers.stub_scheduler import StubScheduler
-from .diffusers.upscale import run_upscale_correction
+from .diffusers.upscale import append_upscale_correction
 from .image.utils import (
     expand_image,
     valid_image,
diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py
index 61e7f5e0c..e7858bd20 100644
--- a/api/onnx_web/chain/__init__.py
+++ b/api/onnx_web/chain/__init__.py
@@ -1,10 +1,8 @@
 from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
-from .blend_controlnet import blend_controlnet
 from .blend_img2img import blend_img2img
 from .blend_inpaint import blend_inpaint
 from .blend_linear import blend_linear
 from .blend_mask import blend_mask
-from .blend_pix2pix import blend_pix2pix
 from .correct_codeformer import correct_codeformer
 from .correct_gfpgan import correct_gfpgan
 from .persist_disk import persist_disk
@@ -16,18 +14,17 @@
 from .source_txt2img import source_txt2img
 from .source_url import source_url
 from .upscale_bsrgan import upscale_bsrgan
+from .upscale_highres import upscale_highres
 from .upscale_outpaint import upscale_outpaint
 from .upscale_resrgan import upscale_resrgan
 from .upscale_stable_diffusion import upscale_stable_diffusion
 from .upscale_swinir import upscale_swinir
 
 CHAIN_STAGES = {
-    "blend-controlnet": blend_controlnet,
     "blend-img2img": blend_img2img,
     "blend-inpaint": blend_inpaint,
     "blend-linear": blend_linear,
     "blend-mask": blend_mask,
-    "blend-pix2pix": blend_pix2pix,
     "correct-codeformer": correct_codeformer,
     "correct-gfpgan": correct_gfpgan,
     "persist-disk": persist_disk,
@@ -39,6 +36,7 @@
     "source-txt2img": source_txt2img,
     "source-url": source_url,
     "upscale-bsrgan": upscale_bsrgan,
+    "upscale-highres": upscale_highres,
     "upscale-outpaint": upscale_outpaint,
     "upscale-resrgan": upscale_resrgan,
     "upscale-stable-diffusion": upscale_stable_diffusion,
diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py
index cc9fd56b8..112b64bfe 100644
--- a/api/onnx_web/chain/base.py
+++ b/api/onnx_web/chain/base.py
@@ -88,7 +88,7 @@ def __call__(
         job: WorkerContext,
         server: ServerContext,
         params: ImageParams,
-        source: Image.Image,
+        source: Optional[Image.Image] = None,
         callback: Optional[ProgressCallback] = None,
         **pipeline_kwargs
     ) -> Image.Image:
diff --git a/api/onnx_web/chain/blend_controlnet.py b/api/onnx_web/chain/blend_controlnet.py
deleted file mode 100644
index c944b767d..000000000
--- a/api/onnx_web/chain/blend_controlnet.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from logging import getLogger
-from typing import Optional
-
-import numpy as np
-from PIL import Image
-
-from ..diffusers.load import load_pipeline
-from ..params import ImageParams, StageParams
-from ..server import ServerContext
-from ..worker import ProgressCallback, WorkerContext
-
-logger = getLogger(__name__)
-
-
-def blend_controlnet(
-    job: WorkerContext,
-    server: ServerContext,
-    _stage: StageParams,
-    params: ImageParams,
-    source: Image.Image,
-    *,
-    callback: Optional[ProgressCallback] = None,
-    stage_source: Image.Image,
-    **kwargs,
-) -> Image.Image:
-    params = params.with_args(**kwargs)
-    source = stage_source or source
-    logger.info(
-        "blending image using ControlNet, %s steps: %s", params.steps, params.prompt
-    )
-
-    pipe = load_pipeline(
-        server,
-        params,
-        "controlnet",
-        job.get_device(),
-    )
-
-    rng = np.random.RandomState(params.seed)
-    result = pipe(
-        params.prompt,
-        generator=rng,
-        guidance_scale=params.cfg,
-        image=source,
-        negative_prompt=params.negative_prompt,
-        num_inference_steps=params.steps,
-        strength=params.strength,  # TODO: ControlNet strength
-        callback=callback,
-    )
-
-    output = result.images[0]
-
-    logger.info("final output image size: %sx%s", output.width, output.height)
-    return output
diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index c825b9e54..c2272dc02 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -6,6 +6,7 @@
 from PIL import Image
 
 from ..diffusers.load import load_pipeline
+from ..diffusers.utils import encode_prompt, parse_prompt
 from ..params import ImageParams, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -20,8 +21,9 @@
     params: ImageParams,
     source: Image.Image,
     *,
+    strength: float,
     callback: Optional[ProgressCallback] = None,
-    stage_source: Image.Image,
+    stage_source: Optional[Image.Image] = None,
     **kwargs,
 ) -> Image.Image:
     params = params.with_args(**kwargs)
@@ -30,14 +32,28 @@
         "blending image using img2img, %s steps: %s", params.steps, params.prompt
     )
 
-    pipe_type = "lpw" if params.lpw() else "img2img"
+    prompt_pairs, loras, inversions = parse_prompt(params)
+
+    pipe_type = params.get_valid_pipeline("img2img")
     pipe = load_pipeline(
         server,
         params,
         pipe_type,
         job.get_device(),
-        # TODO: add LoRAs and TIs
+        inversions=inversions,
+        loras=loras,
     )
+
+    pipe_params = {}
+    if pipe_type == "controlnet":
+        pipe_params["controlnet_conditioning_scale"] = strength
+    elif pipe_type == "img2img":
+        pipe_params["strength"] = strength
+    elif pipe_type == "panorama":
+        pipe_params["strength"] = strength
+    elif pipe_type == "pix2pix":
+        pipe_params["image_guidance_scale"] = strength
+
     if params.lpw():
         logger.debug("using LPW pipeline for img2img")
         rng = torch.manual_seed(params.seed)
@@ -50,8 +66,13 @@
             num_inference_steps=params.steps,
-            strength=params.strength,
+            strength=strength,
             callback=callback,
+            **pipe_params,
         )
     else:
+        # encode and record alternative prompts outside of LPW
+        prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
+        pipe.unet.set_prompts(prompt_embeds)
+
         rng = np.random.RandomState(params.seed)
         result = pipe(
             params.prompt,
@@ -62,6 +83,6 @@
             num_inference_steps=params.steps,
-            strength=params.strength,
             callback=callback,
+            **pipe_params,
         )
 
     output = result.images[0]
diff --git a/api/onnx_web/chain/blend_pix2pix.py b/api/onnx_web/chain/blend_pix2pix.py
deleted file mode 100644
index a4a2c6143..000000000
--- a/api/onnx_web/chain/blend_pix2pix.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from logging import getLogger
-from typing import Optional
-
-import numpy as np
-import torch
-from PIL import Image
-
-from ..diffusers.load import load_pipeline
-from ..params import ImageParams, StageParams
-from ..server import ServerContext
-from ..worker import ProgressCallback, WorkerContext
-
-logger = getLogger(__name__)
-
-
-def blend_pix2pix(
-    job: WorkerContext,
-    server: ServerContext,
-    _stage: StageParams,
-    params: ImageParams,
-    source: Image.Image,
-    *,
-    callback: Optional[ProgressCallback] = None,
-    stage_source: Image.Image,
-    **kwargs,
-) -> Image.Image:
-    params = params.with_args(**kwargs)
-    source = stage_source or source
-    logger.info(
-        "blending image using instruct pix2pix, %s steps: %s",
-        params.steps,
-        params.prompt,
-    )
-
-    pipe = load_pipeline(
-        server,
-        params,
-        "pix2pix",
-        job.get_device(),
-        # TODO: add LoRAs and TIs
-    )
-    if params.lpw():
-        logger.debug("using LPW pipeline for img2img")
-        rng = torch.manual_seed(params.seed)
-        result = pipe.img2img(
-            params.prompt,
-            generator=rng,
-            guidance_scale=params.cfg,
-            image=source,
-            negative_prompt=params.negative_prompt,
-            num_inference_steps=params.steps,
-            strength=params.strength,
-            callback=callback,
-        )
-    else:
-        rng = np.random.RandomState(params.seed)
-        result = pipe(
-            params.prompt,
-            generator=rng,
-            guidance_scale=params.cfg,
-            image=source,
-            negative_prompt=params.negative_prompt,
-            num_inference_steps=params.steps,
-            strength=params.strength,
-            callback=callback,
-        )
-
-    output = result.images[0]
-
-    logger.info("final output image size: %sx%s", output.width, output.height)
-    return output
diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py
index 436bca205..5e681f712 100644
--- a/api/onnx_web/chain/source_txt2img.py
+++ b/api/onnx_web/chain/source_txt2img.py
@@ -6,7 +6,7 @@
 from PIL import Image
 
 from ..diffusers.load import load_pipeline
-from ..diffusers.utils import get_latents_from_seed
+from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt
 from ..params import ImageParams, Size, StageParams
 from ..server import ServerContext
 from ..worker import ProgressCallback, WorkerContext
@@ -36,14 +36,17 @@
         "a source image was passed to a txt2img stage, and will be discarded"
     )
 
+    prompt_pairs, loras, inversions = parse_prompt(params)
+
     latents = get_latents_from_seed(params.seed, size)
-    pipe_type = "lpw" if params.lpw() else "txt2img"
+    pipe_type = params.get_valid_pipeline("txt2img")
     pipe = load_pipeline(
         server,
         params,
         pipe_type,
         job.get_device(),
-        # TODO: add LoRAs and TIs
+        inversions=inversions,
+        loras=loras,
     )
 
     if params.lpw():
@@ -61,6 +64,10 @@
             callback=callback,
         )
     else:
+        # encode and record alternative prompts outside of LPW
+        prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
+        pipe.unet.set_prompts(prompt_embeds)
+
         rng = np.random.RandomState(params.seed)
         result = pipe(
             params.prompt,
diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py
new file mode 100644
index 000000000..5eb1e3b58
--- /dev/null
+++ b/api/onnx_web/chain/upscale_highres.py
@@ -0,0 +1,108 @@
+from logging import getLogger
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from PIL import Image
+
+from ..diffusers.load import load_pipeline
+from ..diffusers.upscale import append_upscale_correction
+from ..diffusers.utils import parse_prompt
+from ..params import HighresParams, ImageParams, StageParams, UpscaleParams
+from ..server import ServerContext
+from ..worker import WorkerContext
+from ..worker.context import ProgressCallback
+
+logger = getLogger(__name__)
+
+
+def upscale_highres(
+    job: WorkerContext,
+    server: ServerContext,
+    _stage: StageParams,
+    params: ImageParams,
+    source: Image.Image,
+    *,
+    highres: HighresParams,
+    upscale: UpscaleParams,
+    stage_source: Optional[Image.Image] = None,
+    pipeline: Optional[Any] = None,
+    callback: Optional[ProgressCallback] = None,
+    **kwargs,
+) -> Image.Image:
+    source = stage_source or source
+
+    if highres.scale <= 1:
+        return source
+
+    # load img2img pipeline once
+    pipe_type = params.get_valid_pipeline("img2img")
+    logger.debug("using %s pipeline for highres", pipe_type)
+
+    _prompt_pairs, loras, inversions = parse_prompt(params)
+    highres_pipe = pipeline or load_pipeline(
+        server,
+        params,
+        pipe_type,
+        job.get_device(),
+        inversions=inversions,
+        loras=loras,
+    )
+
+    scaled_size = (source.width * highres.scale, source.height * highres.scale)
+
+    if highres.method == "bilinear":
+        logger.debug("using bilinear interpolation for highres")
+        source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
+    elif highres.method == "lanczos":
+        logger.debug("using Lanczos interpolation for highres")
+        source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
+    else:
+        logger.debug("using upscaling pipeline for highres")
+        upscale = append_upscale_correction(
+            StageParams(),
+            params,
+            upscale=upscale.with_args(
+                faces=False,
+                scale=highres.scale,
+                outscale=highres.scale,
+            ),
+        )
+        source = upscale(
+            job,
+            server,
+            params,
+            source,
+            callback=callback,
+        )
+
+    if pipe_type == "lpw":
+        rng = torch.manual_seed(params.seed)
+        result = highres_pipe.img2img(
+            source,
+            params.prompt,
+            generator=rng,
+            guidance_scale=params.cfg,
+            negative_prompt=params.negative_prompt,
+            num_images_per_prompt=1,
+            num_inference_steps=highres.steps,
+            strength=highres.strength,
+            eta=params.eta,
+            callback=callback,
+        )
+        return result.images[0]
+    else:
+        rng = np.random.RandomState(params.seed)
+        result = highres_pipe(
+            params.prompt,
+            source,
+            generator=rng,
+            guidance_scale=params.cfg,
+            negative_prompt=params.negative_prompt,
+            num_images_per_prompt=1,
+            num_inference_steps=highres.steps,
+            strength=highres.strength,
+            eta=params.eta,
+            callback=callback,
+        )
+        return result.images[0]
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index fffb4a30f..b1c489284 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -1,12 +1,10 @@
 from logging import getLogger
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional
 
-import numpy as np
-import torch
 from PIL import Image
 
-from ..chain import blend_mask, upscale_outpaint
-from ..chain.utils import process_tile_order
+from ..chain import blend_img2img, blend_mask, source_txt2img, upscale_highres, upscale_outpaint
+from ..chain.base import ChainPipeline
 from ..output import save_image
 from ..params import (
     Border,
@@ -14,319 +12,73 @@
     ImageParams,
     Size,
     StageParams,
-    TileOrder,
     UpscaleParams,
 )
 from ..server import ServerContext
 from ..server.load import get_source_filters
 from ..utils import run_gc, show_system_toast
 from ..worker import WorkerContext
-from ..worker.context import ProgressCallback
-from .load import load_pipeline
-from .upscale import run_upscale_correction
-from .utils import encode_prompt, get_latents_from_seed, parse_prompt
+from .upscale import append_upscale_correction, split_upscale
+from .utils import parse_prompt
 
 logger = getLogger(__name__)
 
 
-def run_loopback(
-    job: WorkerContext,
-    server: ServerContext,
-    params: ImageParams,
-    strength: float,
-    image: Image.Image,
-    progress: ProgressCallback,
-    inversions: List[Tuple[str, float]],
-    loras: List[Tuple[str, float]],
-    pipeline: Optional[Any] = None,
-) -> Image.Image:
-    if params.loopback == 0:
-        return image
-
-    # load img2img pipeline once
-    pipe_type = params.get_valid_pipeline("img2img")
-    if pipe_type == "controlnet":
-        logger.debug(
-            "controlnet pipeline cannot be used for loopback, switching to img2img"
-        )
-        pipe_type = "img2img"
-
-    logger.debug("using %s pipeline for loopback", pipe_type)
-
-    pipe = pipeline or load_pipeline(
-        server,
-        params,
-        pipe_type,
-        job.get_device(),
-        inversions=inversions,
-        loras=loras,
-    )
-
-    def loopback_iteration(source: Image.Image):
-        if pipe_type == "lpw":
-            rng = torch.manual_seed(params.seed)
-            result = pipe.img2img(
-                source,
-                params.prompt,
-                generator=rng,
-                guidance_scale=params.cfg,
-                negative_prompt=params.negative_prompt,
-                num_images_per_prompt=1,
-                num_inference_steps=params.steps,
-                strength=strength,
-                eta=params.eta,
-                callback=progress,
-            )
-            return result.images[0]
-        else:
-            rng = np.random.RandomState(params.seed)
-            result = pipe(
-                params.prompt,
-                source,
-                generator=rng,
-                guidance_scale=params.cfg,
-                negative_prompt=params.negative_prompt,
-                num_images_per_prompt=1,
-                num_inference_steps=params.steps,
-                strength=strength,
-                eta=params.eta,
-                callback=progress,
-            )
-            return result.images[0]
-
-    for _i in range(params.loopback):
-        image = loopback_iteration(image)
-
-    return image
-
-
-def run_highres(
+def run_txt2img_pipeline(
     job: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     size: Size,
+    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
-    image: Image.Image,
-    progress: ProgressCallback,
-    inversions: List[Tuple[str, float]],
-    loras: List[Tuple[str, float]],
-    pipeline: Optional[Any] = None,
-) -> Image.Image:
-    if highres.scale <= 1:
-        return image
-
-    if upscale.faces and (
-        upscale.upscale_order == "correction-both"
-        or upscale.upscale_order == "correction-first"
-    ):
-        image = run_upscale_correction(
-            job,
-            server,
-            StageParams(),
+) -> None:
+    # prepare the chain pipeline and first stage
+    chain = ChainPipeline()
+    stage = StageParams()
+    chain.append((source_txt2img, stage, {"size": size}))
+
+    # apply upscaling and correction, before highres
+    first_upscale, after_upscale = split_upscale(upscale)
+    if first_upscale:
+        append_upscale_correction(
+            stage,
             params,
-            image,
-            upscale=upscale.with_args(
-                scale=1,
-                outscale=1,
-            ),
-            callback=progress,
+            upscale=first_upscale,
+            chain=chain,
         )
 
-    # load img2img pipeline once
-    pipe_type = params.get_valid_pipeline("img2img")
-    logger.debug("using %s pipeline for highres", pipe_type)
+    # apply highres
+    chain.append((upscale_highres, stage, {"highres": highres, "upscale": upscale}))
 
-    highres_pipe = pipeline or load_pipeline(
-        server,
+    # apply upscaling and correction, after highres
+    append_upscale_correction(
+        stage,
         params,
-        pipe_type,
-        job.get_device(),
-        inversions=inversions,
-        loras=loras,
-    )
-
-    def highres_tile(tile: Image.Image, dims):
-        scaled_size = (tile.width * highres.scale, tile.height * highres.scale)
-
-        if highres.method == "bilinear":
-            logger.debug("using bilinear interpolation for highres")
-            tile = tile.resize(scaled_size, resample=Image.Resampling.BILINEAR)
-        elif highres.method == "lanczos":
-            logger.debug("using Lanczos interpolation for highres")
-            tile = tile.resize(scaled_size, resample=Image.Resampling.LANCZOS)
-        else:
-            logger.debug("using upscaling pipeline for highres")
-            tile = run_upscale_correction(
-                job,
-                server,
-                StageParams(),
-                params,
-                tile,
-                upscale=upscale.with_args(
-                    faces=False,
-                    scale=highres.scale,
-                    outscale=highres.scale,
-                ),
-                callback=progress,
-            )
-
-        if pipe_type == "lpw":
-            rng = torch.manual_seed(params.seed)
-            result = highres_pipe.img2img(
-                tile,
-                params.prompt,
-                generator=rng,
-                guidance_scale=params.cfg,
-                negative_prompt=params.negative_prompt,
-                num_images_per_prompt=1,
-                num_inference_steps=highres.steps,
-                strength=highres.strength,
-                eta=params.eta,
-                callback=progress,
-            )
-            return result.images[0]
-        else:
-            rng = np.random.RandomState(params.seed)
-            result = highres_pipe(
-                params.prompt,
-                tile,
-                generator=rng,
-                guidance_scale=params.cfg,
-                negative_prompt=params.negative_prompt,
-                num_images_per_prompt=1,
-                num_inference_steps=highres.steps,
-                strength=highres.strength,
-                eta=params.eta,
-                callback=progress,
-            )
-            return result.images[0]
-
-    logger.info(
-        "running highres fix for %s iterations at %s scale",
-        highres.iterations,
-        highres.scale,
+        upscale=after_upscale,
+        chain=chain,
     )
 
-    for _i in range(highres.iterations):
-        image = process_tile_order(
-            TileOrder.grid,
-            image,
-            size.height // highres.scale,
-            highres.scale,
-            [highres_tile],
-            overlap=params.overlap,
-        )
-
-    return image
-
+    # run and save
+    image = chain(job, server, params, None)
 
-def run_txt2img_pipeline(
-    job: WorkerContext,
-    server: ServerContext,
-    params: ImageParams,
-    size: Size,
-    outputs: List[str],
-    upscale: UpscaleParams,
-    highres: HighresParams,
-) -> None:
-    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
-    prompt_pairs, loras, inversions = parse_prompt(params)
-
-    pipe_type = params.get_valid_pipeline("txt2img")
-    logger.debug("using %s pipeline for txt2img", pipe_type)
-
-    pipe = load_pipeline(
+    _prompt_pairs, loras, inversions = parse_prompt(params)
+    dest = save_image(
         server,
+        outputs[0],
+        image,
         params,
-        pipe_type,
-        job.get_device(),
+        size,
+        upscale=upscale,
+        highres=highres,
         inversions=inversions,
         loras=loras,
     )
-    progress = job.get_progress_callback()
-
-    if pipe_type == "lpw":
-        rng = torch.manual_seed(params.seed)
-        result = pipe.text2img(
-            params.prompt,
-            height=size.height,
-            width=size.width,
-            generator=rng,
-            guidance_scale=params.cfg,
-            latents=latents,
-            negative_prompt=params.negative_prompt,
-            num_images_per_prompt=params.batch,
-            num_inference_steps=params.steps,
-            eta=params.eta,
-            callback=progress,
-        )
-    else:
-        # encode and record alternative prompts outside of LPW
-        prompt_embeds = encode_prompt(
-            pipe,
-            prompt_pairs,
-            num_images_per_prompt=params.batch,
-            do_classifier_free_guidance=params.do_cfg(),
-        )
-        pipe.unet.set_prompts(prompt_embeds)
-
-        rng = np.random.RandomState(params.seed)
-        result = pipe(
-            params.prompt,
-            height=size.height,
-            width=size.width,
-            generator=rng,
-            guidance_scale=params.cfg,
-            latents=latents,
-            negative_prompt=params.negative_prompt,
-            num_images_per_prompt=params.batch,
-            num_inference_steps=params.steps,
-            eta=params.eta,
-            callback=progress,
-        )
-
-    image_outputs = list(zip(result.images, outputs))
-    del result
-    del pipe
-
-    for image, output in image_outputs:
-        image = run_highres(
-            job,
-            server,
-            params,
-            size,
-            upscale,
-            highres,
-            image,
-            progress,
-            inversions,
-            loras,
-        )
-
-        image = run_upscale_correction(
-            job,
-            server,
-            StageParams(),
-            params,
-            image,
-            upscale=upscale,
-            callback=progress,
-        )
-
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            highres=highres,
-            inversions=inversions,
-            loras=loras,
-        )
 
+    # clean up
     run_gc([job.get_device()])
+
+    # notify the user
     show_system_toast(f"finished txt2img job: {dest}")
     logger.info("finished txt2img job: %s", dest)
 
@@ -342,110 +94,75 @@ def run_img2img_pipeline(
     strength: float,
     source_filter: Optional[str] = None,
 ) -> None:
-    prompt_pairs, loras, inversions = parse_prompt(params)
-
-    # filter the source image
+    # run filter on the source image
     if source_filter is not None:
         f = get_source_filters().get(source_filter, None)
         if f is not None:
             logger.debug("running source filter: %s", f.__name__)
             source = f(server, source)
 
-    pipe_type = params.get_valid_pipeline("img2img")
-    pipe = load_pipeline(
-        server,
-        params,
-        pipe_type,
-        job.get_device(),
-        inversions=inversions,
-        loras=loras,
+    # prepare the chain pipeline and first stage
+    chain = ChainPipeline()
+    stage = StageParams()
+    chain.append(
+        (
+            blend_img2img,
+            stage,
+            {
+                "strength": strength,
+            },
+        )
     )
 
-    pipe_params = {}
-    if pipe_type == "controlnet":
-        pipe_params["controlnet_conditioning_scale"] = strength
-    elif pipe_type == "img2img":
-        pipe_params["strength"] = strength
-    elif pipe_type == "panorama":
-        pipe_params["strength"] = strength
-    elif pipe_type == "pix2pix":
-        pipe_params["image_guidance_scale"] = strength
-
-    progress = job.get_progress_callback()
-    if pipe_type == "lpw":
-        logger.debug("using LPW pipeline for img2img")
-        rng = torch.manual_seed(params.seed)
-        result = pipe.img2img(
-            source,
-            params.prompt,
-            generator=rng,
-            guidance_scale=params.cfg,
-            negative_prompt=params.negative_prompt,
-            num_images_per_prompt=params.batch,
-            num_inference_steps=params.steps,
-            eta=params.eta,
-            callback=progress,
-            **pipe_params,
-        )
-    else:
-        # encode and record alternative prompts outside of LPW
-        prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg())
-        pipe.unet.set_prompts(prompt_embeds)
-
-        rng = np.random.RandomState(params.seed)
-        result = pipe(
-            params.prompt,
-            source,
-            generator=rng,
-            guidance_scale=params.cfg,
-            negative_prompt=params.negative_prompt,
-            num_images_per_prompt=params.batch,
-            num_inference_steps=params.steps,
-            eta=params.eta,
-            callback=progress,
-            **pipe_params,
+    # apply upscaling and correction, before highres
+    first_upscale, after_upscale = split_upscale(upscale)
+    if first_upscale:
+        append_upscale_correction(
+            stage,
+            params,
+            upscale=first_upscale,
+            chain=chain,
         )
 
-    images = result.images
-    if source_filter is not None and source_filter != "none":
-        images.append(source)
+    # loopback through multiple img2img iterations
+    if params.loopback > 0:
+        for _i in range(params.loopback):
+            chain.append(
+                (
+                    blend_img2img,
+                    stage,
+                    {
+                        "strength": strength,
+                    },
+                )
+            )
 
-    for image, output in zip(images, outputs):
-        image = run_loopback(
-            job,
-            server,
-            params,
-            strength,
-            image,
-            progress,
-            inversions,
-            loras,
-        )
+    # highres, if selected
+    if highres.iterations > 0:
+        for _i in range(highres.iterations):
+            chain.append((upscale_highres, stage, {"highres": highres, "upscale": upscale}))
 
-        size = Size(*source.size)
-        image = run_highres(
-            job,
-            server,
-            params,
-            size,
-            upscale,
-            highres,
-            image,
-            progress,
-            inversions,
-            loras,
-        )
+    # apply upscaling and correction, after highres
+    append_upscale_correction(
+        stage,
+        params,
+        upscale=after_upscale,
+        chain=chain,
+    )
 
-        image = run_upscale_correction(
-            job,
-            server,
-            StageParams(),
-            params,
-            image,
-            upscale=upscale,
-            callback=progress,
-        )
+    # run and append the filtered source
+    images = [
+        chain(job, server, params, source),
+    ]
+
+    if source_filter is not None and source_filter != "none":
+        images.append(source)
+
+    # save with metadata
+    _prompt_pairs, loras, inversions = parse_prompt(params)
+    size = Size(*source.size)
 
+    for image, output in zip(images, outputs):
         dest = save_image(
             server,
             output,
@@ -458,7 +175,10 @@
             loras=loras,
         )
 
+    # clean up
     run_gc([job.get_device()])
+
+    # notify the user
     show_system_toast(f"finished img2img job: {dest}")
     logger.info("finished img2img job: %s", dest)
 
@@ -479,49 +199,49 @@
     fill_color: str,
     tile_order: str,
 ) -> None:
-    progress = job.get_progress_callback()
-    stage = StageParams(tile_order=tile_order)
-
-    _prompt_pairs, loras, inversions = parse_prompt(params)
+    logger.debug("building inpaint pipeline")
 
-    logger.debug("applying mask filter and generating noise source")
-    image = upscale_outpaint(
-        job,
-        server,
-        stage,
-        params,
-        source,
-        border=border,
-        stage_mask=mask,
-        fill_color=fill_color,
-        mask_filter=mask_filter,
-        noise_source=noise_source,
-        callback=progress,
+    # set up the chain pipeline and base stage
+    chain = ChainPipeline()
+    stage = StageParams(tile_order=tile_order)
+    chain.append(
+        (
+            upscale_outpaint,
+            stage,
+            {
+                "border": border,
+                "stage_mask": mask,
+                "fill_color": fill_color,
+                "mask_filter": mask_filter,
+                "noise_source": noise_source,
+            },
+        )
    )
 
-    image = run_highres(
-        job,
-        server,
-        params,
-        size,
-        upscale,
-        highres,
-        image,
-        progress,
-        inversions,
-        loras,
+    # apply highres
+    chain.append(
+        (
+            upscale_highres,
+            stage,
+            {
+                "highres": highres, "upscale": upscale,
+            },
+        )
     )
 
-    image = run_upscale_correction(
-        job,
-        server,
+    # apply upscaling and correction
+    append_upscale_correction(
         stage,
         params,
-        image,
         upscale=upscale,
-        callback=progress,
+        chain=chain,
     )
 
+    # run and save
+    image = chain(job, server, params, source)
+
+    _prompt_pairs, loras, inversions = parse_prompt(params)
     dest = save_image(
         server,
         outputs[0],
@@ -534,9 +254,11 @@
         loras=loras,
     )
 
+    # clean up
     del image
     run_gc([job.get_device()])
 
+    # notify the user
     show_system_toast(f"finished inpaint job: {dest}")
     logger.info("finished inpaint job: %s", dest)
 
@@ -551,34 +273,50 @@
     highres: HighresParams,
     source: Image.Image,
 ) -> None:
-    progress = job.get_progress_callback()
+    # set up the chain pipeline, no base stage for upscaling
+    chain = ChainPipeline()
     stage = StageParams()
 
-    _prompt_pairs, loras, inversions = parse_prompt(params)
+    # apply upscaling and correction, before highres
+    first_upscale, after_upscale = split_upscale(upscale)
+    if first_upscale:
+        append_upscale_correction(
+            stage,
+            params,
+            upscale=first_upscale,
+            chain=chain,
+        )
 
-    image = run_upscale_correction(
-        job, server, stage, params, source, upscale=upscale, callback=progress
+    # apply highres
+    chain.append((upscale_highres, stage, {"highres": highres, "upscale": upscale}))
+
+    # apply upscaling and correction, after highres
+    append_upscale_correction(
+        stage,
+        params,
+        upscale=after_upscale,
+        chain=chain,
     )
 
-    # TODO: should this come first?
-    image = run_highres(
-        job,
+    # run and save
+    image = chain(job, server, params, source)
+
+    _prompt_pairs, loras, inversions = parse_prompt(params)
+    dest = save_image(
         server,
+        outputs[0],
+        image,
         params,
         size,
-        upscale,
-        highres,
-        image,
-        progress,
-        inversions,
-        loras,
+        upscale=upscale,
+        inversions=inversions,
+        loras=loras,
     )
 
-    dest = save_image(server, outputs[0], image, params, size, upscale=upscale)
-
+    # clean up
     del image
     run_gc([job.get_device()])
 
+    # notify the user
     show_system_toast(f"finished upscale job: {dest}")
     logger.info("finished upscale job: %s", dest)
 
@@ -594,28 +332,27 @@
     sources: List[Image.Image],
     mask: Image.Image,
 ) -> None:
-    progress = job.get_progress_callback()
+    # set up the chain pipeline and base stage
+    chain = ChainPipeline()
     stage = StageParams()
+    chain.append((blend_mask, stage, {"sources": sources, "stage_mask": mask}))
 
-    image = blend_mask(
-        job,
-        server,
+    # apply upscaling and correction
+    append_upscale_correction(
         stage,
         params,
-        sources=sources,
-        stage_mask=mask,
-        callback=progress,
-    )
-    image = image.convert("RGB")
-
-    image = run_upscale_correction(
-        job, server, stage, params, image, upscale=upscale, callback=progress
+        upscale=upscale,
+        chain=chain,
     )
 
+    # run and save
+    image = chain(job, server, params, sources[0])
     dest = save_image(server, outputs[0], image, params, size, upscale=upscale)
 
+    # clean up
     del image
     run_gc([job.get_device()])
 
+    # notify the user
     show_system_toast(f"finished blend job: {dest}")
     logger.info("finished blend job: %s", dest)
diff --git a/api/onnx_web/diffusers/upscale.py b/api/onnx_web/diffusers/upscale.py
index c7d3601ab..064b0f4b6 100644
--- a/api/onnx_web/diffusers/upscale.py
+++ b/api/onnx_web/diffusers/upscale.py
@@ -1,7 +1,5 @@
 from logging import getLogger
-from typing import List, Optional
-
-from PIL import Image
+from typing import List, Optional, Tuple
 
 from ..chain import (
     ChainPipeline,
@@ -14,24 +12,42 @@
     upscale_swinir,
 )
 from ..params import ImageParams, SizeChart, StageParams, UpscaleParams
-from ..server import ServerContext
-from ..worker import ProgressCallback, WorkerContext
 
 logger = getLogger(__name__)
 
 
-def run_upscale_correction(
-    job: WorkerContext,
-    server: ServerContext,
+def split_upscale(
+    upscale: UpscaleParams,
+) -> Tuple[Optional[UpscaleParams], UpscaleParams]:
+    if upscale.faces and (
+        upscale.upscale_order == "correction-both"
+        or upscale.upscale_order == "correction-first"
+    ):
+        return (
+            upscale.with_args(
+                scale=1,
+                outscale=1,
+            ),
+            upscale.with_args(
+                upscale_order="correction-last",
+            ),
+        )
+    else:
+        return (
+            None,
+            upscale,
+        )
+
+
+def append_upscale_correction(
     stage: StageParams,
     params: ImageParams,
-    image: Image.Image,
     *,
     upscale: UpscaleParams,
-    callback: Optional[ProgressCallback] = None,
+    chain: Optional[ChainPipeline] = None,
     pre_stages: List[PipelineStage] = None,
     post_stages: List[PipelineStage] = None,
-) -> Image.Image:
+) -> ChainPipeline:
     """
     This is a convenience method for a chain pipeline that will run upscaling and
     correction, based on the `upscale` params.
@@ -42,7 +58,9 @@
         upscale.outscale,
     )
 
-    chain = ChainPipeline()
+    if chain is None:
+        chain = ChainPipeline()
+
     if pre_stages is not None:
         for stage, params in pre_stages:
             chain.append((stage, params))
@@ -103,12 +121,4 @@
     for stage, params in post_stages:
         chain.append((stage, params))
 
-    return chain(
-        job,
-        server,
-        params,
-        image,
-        prompt=params.prompt,
-        upscale=upscale,
-        callback=callback,
-    )
+    return chain
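
Note on the new calling convention, since this diff removes the one-shot helpers: run_upscale_correction, run_loopback, and run_highres are replaced by stages appended to a single ChainPipeline that executes once per job, and split_upscale decides whether a face-correction pass must run before highres. A minimal sketch, assuming job, server, params, upscale, and source are already in scope the way they are inside the run_*_pipeline functions above:

    # build one chain for the whole job instead of running each helper eagerly
    chain = ChainPipeline()
    stage = StageParams()

    # correction-first face restoration runs before any highres stages;
    # everything else (scaling, correction-last) runs after
    first_upscale, after_upscale = split_upscale(upscale)
    if first_upscale:
        append_upscale_correction(stage, params, upscale=first_upscale, chain=chain)

    append_upscale_correction(stage, params, upscale=after_upscale, chain=chain)

    # the chain only executes here; every append above was just planning
    image = chain(job, server, params, source)

Because append_upscale_correction now returns the chain rather than an image, stages from several helpers can be queued against the same pipeline before anything runs.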