diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py
index 8e7a64877..6aa360710 100644
--- a/api/onnx_web/chain/base.py
+++ b/api/onnx_web/chain/base.py
@@ -9,7 +9,7 @@
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
-from .utils import process_tile_grid
+from .utils import process_tile_order
logger = getLogger(__name__)
@@ -100,8 +100,12 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:
return tile
- image = process_tile_grid(
- image, stage_params.tile_size, stage_params.outscale, [stage_tile]
+ image = process_tile_order(
+ stage_params.tile_order,
+ image,
+ stage_params.tile_size,
+ stage_params.outscale,
+ [stage_tile],
)
else:
logger.info("image within tile size, running stage")
diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py
index a980e5a53..350bdfe0b 100644
--- a/api/onnx_web/chain/blend_inpaint.py
+++ b/api/onnx_web/chain/blend_inpaint.py
@@ -12,7 +12,7 @@
from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug
-from .utils import process_tile_grid
+from .utils import process_tile_order
logger = getLogger(__name__)
@@ -101,7 +101,9 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
return result.images[0]
- output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
+ output = process_tile_order(
+ stage.tile_order, source_image, SizeChart.auto, 1, [outpaint]
+ )
logger.info("final output image size", output.size)
return output
diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py
index ffd55b026..2f64f62c3 100644
--- a/api/onnx_web/chain/upscale_outpaint.py
+++ b/api/onnx_web/chain/upscale_outpaint.py
@@ -10,9 +10,9 @@
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image
-from ..params import Border, ImageParams, Size, SizeChart, StageParams
+from ..params import Border, ImageParams, Size, SizeChart, StageParams, TileOrder
from ..utils import ServerContext, is_debug
-from .utils import process_tile_grid, process_tile_spiral
+from .utils import process_tile_grid, process_tile_order
logger = getLogger(__name__)
@@ -120,8 +120,13 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
"outpainting with an even border, using spiral tiling with %s overlap",
overlap,
)
- output = process_tile_spiral(
- source_image, SizeChart.auto, 1, [outpaint], overlap=overlap
+ output = process_tile_order(
+ stage.tile_order,
+ source_image,
+ SizeChart.auto,
+ 1,
+ [outpaint],
+ overlap=overlap,
)
else:
logger.debug("outpainting with an uneven border, using grid tiling")
diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py
index 1bb3ac220..053a093fa 100644
--- a/api/onnx_web/chain/utils.py
+++ b/api/onnx_web/chain/utils.py
@@ -3,6 +3,8 @@
from PIL import Image
+from ..params import TileOrder
+
logger = getLogger(__name__)
@@ -16,6 +18,7 @@ def process_tile_grid(
tile: int,
scale: int,
filters: List[TileCallback],
+ **kwargs,
) -> Image.Image:
width, height = source.size
image = Image.new("RGB", (width * scale, height * scale))
@@ -46,6 +49,7 @@ def process_tile_spiral(
scale: int,
filters: List[TileCallback],
overlap: float = 0.5,
+ **kwargs,
) -> Image.Image:
if scale != 1:
raise Exception("unsupported scale")
@@ -87,3 +91,22 @@ def process_tile_spiral(
image.paste(tile_image, (left * scale, top * scale))
return image
+
+
+def process_tile_order(
+ order: TileOrder,
+ source: Image.Image,
+ tile: int,
+ scale: int,
+ filters: List[TileCallback],
+ **kwargs,
+) -> Image.Image:
+ if order == TileOrder.grid:
+ logger.debug("using grid tile order with tile size: %s", tile)
+ return process_tile_grid(source, tile, scale, filters, **kwargs)
+ elif order == TileOrder.kernel:
+ logger.debug("using kernel tile order with tile size: %s", tile)
+ raise NotImplementedError()
+ elif order == TileOrder.spiral:
+ logger.debug("using spiral tile order with tile size: %s", tile)
+ return process_tile_spiral(source, tile, scale, filters, **kwargs)
diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py
index e8924b946..1c96bfd76 100644
--- a/api/onnx_web/diffusion/run.py
+++ b/api/onnx_web/diffusion/run.py
@@ -151,10 +151,11 @@ def run_inpaint_pipeline(
mask_filter: Any,
strength: float,
fill_color: str,
+ tile_order: str,
) -> None:
# device = job.get_device()
# progress = job.get_progress_callback()
- stage = StageParams()
+ stage = StageParams(tile_order=tile_order)
image = upscale_outpaint(
job,
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index f81e61c96..a6144c245 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -14,6 +14,12 @@ class SizeChart(IntEnum):
hd64k = 2**16
+class TileOrder:
+ grid = "grid"
+ kernel = "kernel"
+ spiral = "spiral"
+
+
Param = Union[str, int, float]
Point = Tuple[int, int]
@@ -122,13 +128,15 @@ class StageParams:
def __init__(
self,
name: Optional[str] = None,
- tile_size: int = SizeChart.auto,
outscale: int = 1,
+ tile_order: str = TileOrder.grid,
+ tile_size: int = SizeChart.auto,
# batch_size: int = 1,
) -> None:
self.name = name
- self.tile_size = tile_size
self.outscale = outscale
+ self.tile_order = tile_order
+ self.tile_size = tile_size
class UpscaleParams:
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 67e049e61..79fdc57b8 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -64,7 +64,7 @@
noise_source_uniform,
)
from .output import json_params, make_output_name
-from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
+from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams, TileOrder
from .utils import (
ServerContext,
base_join,
@@ -589,6 +589,7 @@ def inpaint():
get_config_value("strength", "max"),
get_config_value("strength", "min"),
)
+ tile_order = get_from_list(request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral])
output = make_output_name(
context,
@@ -604,6 +605,7 @@ def inpaint():
noise_source.__name__,
strength,
fill_color,
+ tile_order,
),
)
logger.info("inpaint job queued for: %s", output)
@@ -625,6 +627,7 @@ def inpaint():
mask_filter,
strength,
fill_color,
+ tile_order,
needs_device=device,
)
diff --git a/api/params.json b/api/params.json
index 6f194a0ee..f1a9824dc 100644
--- a/api/params.json
+++ b/api/params.json
@@ -66,10 +66,6 @@
"default": "histogram",
"keys": []
},
- "order": {
- "default": "spiral",
- "keys": []
- },
"outscale": {
"default": 1,
"min": 1,
@@ -118,6 +114,10 @@
"max": 1,
"step": 0.01
},
+ "tileOrder": {
+ "default": "spiral",
+ "keys": []
+ },
"top": {
"default": 0,
"min": 0,
diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx
index bfd9983f6..2098857d0 100644
--- a/gui/src/components/tab/Inpaint.tsx
+++ b/gui/src/components/tab/Inpaint.tsx
@@ -1,5 +1,5 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
-import { Box, Button, FormControl, FormControlLabel, InputLabel, Select, Stack } from '@mui/material';
+import { Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useMutation, useQuery, useQueryClient } from 'react-query';
@@ -161,7 +161,7 @@ export function Inpaint() {
tileOrder: e.target.value,
});
}}
- >
+ >{['grid', 'kernel', 'spiral'].map((name) => )}