diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 253a881c8..602a7438f 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -11,7 +11,7 @@ from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..utils import ServerContext, is_debug -from .utils import process_tile_spiral +from .utils import process_tile_grid, process_tile_spiral logger = getLogger(__name__) @@ -92,7 +92,16 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]): draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") return result.images[0] - output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint]) + margin_x = float(max(border.left, border.right)) + margin_y = float(max(border.top, border.bottom)) + overlap = min(margin_x / source_image.width, margin_y / source_image.height) + + if overlap > 0 and border.left == border.right and border.top == border.bottom: + logger.debug("outpainting with an even border, using spiral tiling") + output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint], overlap=overlap) + else: + logger.debug("outpainting with an uneven border, using grid tiling") + output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) logger.info("final output image size: %sx%s", output.width, output.height) return output diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 0e112303d..1bb3ac220 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -57,18 +57,17 @@ def process_tile_spiral( center_x = (width // 2) - (tile // 2) center_y = (height // 2) - (tile // 2) - # TODO: only valid for overlap = 0.5 - if overlap == 0.5: - tiles = [ - (0, tile * -overlap), - (tile * overlap, tile * -overlap), - (tile * overlap, 0), - (tile * overlap, tile * overlap), - (0, tile * overlap), - (tile * -overlap, tile * overlap), - (tile * -overlap, 0), - (tile * -overlap, tile * -overlap), - ] + # TODO: should add/remove tiles when overlap != 0.5 + tiles = [ + (0, tile * -overlap), + (tile * overlap, tile * -overlap), + (tile * overlap, 0), + (tile * overlap, tile * overlap), + (0, tile * overlap), + (tile * -overlap, tile * overlap), + (tile * -overlap, 0), + (tile * -overlap, tile * -overlap), + ] # tile tuples is source, multiply by scale for dest counter = 0