Skip to content

Commit

Permalink
feat(api): remove size restrictions on most pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 1, 2023
1 parent 934dabb commit 5e1b700
Show file tree
Hide file tree
Showing 7 changed files with 3 additions and 44 deletions.
1 change: 0 additions & 1 deletion api/onnx_web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from .diffusers.upscale import stage_upscale_correction
from .image.utils import (
expand_image,
valid_image,
)
from .image.mask_filter import (
mask_filter_gaussian_multiply,
Expand Down
5 changes: 1 addition & 4 deletions api/onnx_web/chain/blend_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from PIL import Image

from ..image import valid_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
Expand All @@ -24,6 +23,4 @@ def blend_linear(
) -> Image.Image:
logger.info("blending image using linear interpolation")

resized = [valid_image(s) for s in sources]

return Image.blend(resized[1], resized[0], alpha)
return Image.blend(sources[1], sources[0], alpha)
5 changes: 2 additions & 3 deletions api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from logging import getLogger
from typing import List, Optional
from typing import Optional

from PIL import Image

from ..image import valid_image
from ..output import save_image
from ..params import ImageParams, StageParams
from ..server import ServerContext
Expand Down Expand Up @@ -35,4 +34,4 @@ def blend_mask(
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)

return Image.composite(source, stage_source, mult_mask)
return Image.composite(stage_source, source, mult_mask)
1 change: 0 additions & 1 deletion api/onnx_web/chain/reduce_thumbnail.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def reduce_thumbnail(
source = stage_source or source
image = source.copy()

# TODO: should use a call to valid_image
image = image.thumbnail((size.width, size.height))

logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
Expand Down
1 change: 0 additions & 1 deletion api/onnx_web/image/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .utils import (
expand_image,
valid_image,
)
from .mask_filter import (
mask_filter_gaussian_multiply,
Expand Down
21 changes: 0 additions & 21 deletions api/onnx_web/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,3 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))

return (full_source, full_mask, full_noise, size)


def valid_image(
image: Image.Image,
min_dims: Union[Size, Tuple[int, int]] = [512, 512],
max_dims: Union[Size, Tuple[int, int]] = [512, 512],
) -> Image.Image:
min_x, min_y = min_dims
max_x, max_y = max_dims

if image.width > max_x or image.height > max_y:
image = ImageOps.contain(image, (max_x, max_y))

if image.width < min_x or image.height < min_y:
blank = Image.new(image.mode, (min_x, min_y), "black")
blank.paste(image)
image = blank

# check for square

return image
13 changes: 0 additions & 13 deletions api/onnx_web/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
run_txt2img_pipeline,
run_upscale_pipeline,
)
from ..image import valid_image # mask filters; noise sources
from ..output import json_params, make_output_name
from ..params import Border, StageParams, TileOrder, UpscaleParams
from ..transformers.run import run_txt2txt_pipeline
Expand Down Expand Up @@ -189,12 +188,6 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
server, "img2img", params, size, extras=[strength], count=output_count
)

if params.get_valid_pipeline("img2img") != "panorama":
logger.info(
"resizing input image for limited pipeline, use panorama pipeline for full-size"
)
source = valid_image(source, min_dims=size, max_dims=size)

job_name = output[0]
pool.submit(
job_name,
Expand Down Expand Up @@ -323,9 +316,6 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):

output = make_output_name(server, "upscale", params, size)

logger.info("resizing source image for limited pipeline")
source = valid_image(source, min_dims=size, max_dims=size)

job_name = output[0]
pool.submit(
job_name,
Expand Down Expand Up @@ -396,7 +386,6 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
source_file = request.files.get(stage_source_name)
if source_file is not None:
source = Image.open(BytesIO(source_file.read())).convert("RGB")
source = valid_image(source, max_dims=(size.width, size.height))
kwargs["stage_source"] = source

if stage_mask_name in request.files:
Expand All @@ -408,7 +397,6 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
mask_file = request.files.get(stage_mask_name)
if mask_file is not None:
mask = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask = valid_image(mask, max_dims=(size.width, size.height))
kwargs["stage_mask"] = mask

pipeline.append((callback, stage, kwargs))
Expand Down Expand Up @@ -447,7 +435,6 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
logger.warning("missing source %s", i)
else:
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)

device, params, size = pipeline_from_request(server)
Expand Down

0 comments on commit 5e1b700

Please sign in to comment.