Skip to content

Commit

Permalink
fix(api): resize images to min dimensions by padding if necessary (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 18, 2023
1 parent 3dde3b9 commit 0e108da
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 13 deletions.
8 changes: 4 additions & 4 deletions api/onnx_web/chain/blend_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

from PIL import Image
from onnx_web.image import valid_image

from onnx_web.output import save_image

Expand All @@ -18,7 +19,7 @@ def blend_mask(
_stage: StageParams,
_params: ImageParams,
*,
sources: Optional[List[Image.Image]] = None,
resized: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None,
**kwargs,
Expand All @@ -33,7 +34,6 @@ def blend_mask(
save_image(server, "last-mask.png", mask)
save_image(server, "last-mult-mask.png", mult_mask)

for source in sources:
source.thumbnail(mult_mask.size)
resized = [valid_image(s, min_dims=mult_mask.size, max_dims=mult_mask.size) for s in resized]

return Image.composite(sources[0], sources[1], mult_mask)
return Image.composite(resized[0], resized[1], mult_mask)
2 changes: 1 addition & 1 deletion api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def run_blend_pipeline(
server,
stage,
params,
sources=sources,
resized=sources,
mask=mask,
callback=progress,
)
Expand Down
24 changes: 23 additions & 1 deletion api/onnx_web/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from numpy import random
from PIL import Image, ImageChops, ImageFilter
from PIL import Image, ImageChops, ImageFilter, ImageOps
from typing import Tuple

from .params import Border, Point

Expand Down Expand Up @@ -189,3 +190,24 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))

return (full_source, full_mask, full_noise, (full_width, full_height))


def valid_image(
image: Image.Image,
min_dims: Tuple[int, int] = [512, 512],
max_dims: Tuple[int, int] = [512, 512],
) -> Image.Image:
min_x, min_y = min_dims
max_x, max_y = max_dims

if image.width > max_x or image.height > max_y:
image = ImageOps.contain(image, max_dims)

if image.width < min_x or image.height < min_y:
blank = Image.new(image.mode, min_dims, "black")
blank.paste(image)
image = blank

# check for square

return image
18 changes: 11 additions & 7 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
noise_source_histogram,
noise_source_normal,
noise_source_uniform,
valid_image,
)
from .output import json_params, make_output_name
from .params import (
Expand Down Expand Up @@ -508,7 +509,7 @@ def img2img():
output = make_output_name(context, "img2img", params, size, extras=(strength,))
logger.info("img2img job queued for: %s", output)

source_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit(
output,
run_img2img_pipeline,
Expand Down Expand Up @@ -597,8 +598,8 @@ def inpaint():
)
logger.info("inpaint job queued for: %s", output)

source_image.thumbnail((size.width, size.height))
mask_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, min_dims=size, max_dims=size)
mask_image = valid_image(mask_image, min_dims=size, max_dims=size)
executor.submit(
output,
run_inpaint_pipeline,
Expand Down Expand Up @@ -635,7 +636,7 @@ def upscale():
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)

source_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, min_dims=size, max_dims=size)
executor.submit(
output,
run_upscale_pipeline,
Expand Down Expand Up @@ -702,7 +703,7 @@ def chain():
)
source_file = request.files.get(stage_source_name)
source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
source_image.thumbnail((size.width, size.height))
source_image = valid_image(source_image, max_dims=(size.width, size.height))
kwargs["source_image"] = source_image

if stage_mask_name in request.files:
Expand All @@ -713,7 +714,7 @@ def chain():
)
mask_file = request.files.get(stage_mask_name)
mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
mask_image.thumbnail((size.width, size.height))
mask_image = valid_image(mask_image, max_dims=(size.width, size.height))
kwargs["mask_image"] = mask_image

pipeline.append((callback, stage, kwargs))
Expand Down Expand Up @@ -743,13 +744,16 @@ def blend():

mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
mask = valid_image(mask)

max_sources = 2
sources = []

for i in range(max_sources):
source_file = request.files.get("source:%s" % (i))
sources.append(Image.open(BytesIO(source_file.read())).convert("RGBA"))
source = Image.open(BytesIO(source_file.read())).convert("RGBA")
source = valid_image(source, mask.size, mask.size)
sources.append(source)

device, params, size = pipeline_from_request()
upscale = upscale_from_request()
Expand Down

0 comments on commit 0e108da

Please sign in to comment.