Skip to content

Commit

Permalink
fix(api): preserve new pixels after outpainting
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 29, 2023
1 parent 20beff8 commit 7083505
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 9 additions & 4 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
OnnxStableDiffusionInpaintPipeline,
)
from logging import getLogger
from PIL import Image
from PIL import Image, ImageDraw
from typing import Callable, Tuple

from ..diffusion.load import (
Expand Down Expand Up @@ -41,7 +41,7 @@ def upscale_outpaint(
params: ImageParams,
source_image: Image.Image,
*,
expand: Border,
border: Border,
prompt: str = None,
mask_image: Image.Image = None,
fill_color: str = 'white',
Expand All @@ -50,7 +50,7 @@ def upscale_outpaint(
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info('upscaling image by expanding borders: %s', expand)
logger.info('upscaling image by expanding borders: %s', border)

if mask_image is None:
# if no mask was provided, keep the full source image
Expand All @@ -59,11 +59,13 @@ def upscale_outpaint(
source_image, mask_image, noise_image, _full_dims = expand_image(
source_image,
mask_image,
expand,
border,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter)

draw_mask = ImageDraw.Draw(mask_image)

if is_debug():
source_image.save(base_join(ctx.output_path, 'last-source.png'))
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
Expand Down Expand Up @@ -98,6 +100,9 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
num_inference_steps=params.steps,
width=size.width,
)

# once part of the image has been drawn, keep it
draw_mask.rectangle((left, top, left + tile, top + tile), fill='black')
return result.images[0]

output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def process_tile_spiral(
top = center_y + int(top)

counter += 1
logger.info('processing tile %s of %s', counter, len(tiles))
logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top)

# TODO: only valid for scale == 1, resize source for others
tile_image = image.crop((left, top, left + tile, top + tile))
Expand Down

0 comments on commit 7083505

Please sign in to comment.