Skip to content

Commit

Permalink
feat(api): implement spiral grid for outpainting
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 29, 2023
1 parent 680adc7 commit a4d3f18
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 15 deletions.
4 changes: 2 additions & 2 deletions api/onnx_web/chain/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
ServerContext,
)
from .utils import (
process_tiles,
process_tile_grid,
)

logger = getLogger(__name__)
Expand Down Expand Up @@ -86,7 +86,7 @@ def stage_tile(tile: Image.Image, _dims) -> Image.Image:

return tile

image = process_tiles(
image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
logger.info('source image within tile size, running stage')
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/chain/blend_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ServerContext,
)
from .utils import (
process_tiles,
process_tile_grid,
)

import numpy as np
Expand Down Expand Up @@ -98,7 +98,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
)
return result.images[0]

output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])

logger.info('final output image size', output.size)
return output
4 changes: 2 additions & 2 deletions api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ServerContext,
)
from .utils import (
process_tiles,
process_tile_spiral,
)

import numpy as np
Expand Down Expand Up @@ -98,7 +98,7 @@ def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
)
return result.images[0]

output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])

logger.info('final output image size: %sx%s', output.width, output.height)
return output
46 changes: 45 additions & 1 deletion api/onnx_web/chain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Imag
pass


def process_tiles(
def process_tile_grid(
source: Image.Image,
tile: int,
scale: int,
Expand All @@ -37,3 +37,47 @@ def process_tiles(
image.paste(tile_image, (left * scale, top * scale))

return image


def process_tile_spiral(
source: Image.Image,
tile: int,
scale: int,
filters: List[TileCallback],
overlap: float = 0.5,
) -> Image.Image:
if scale != 1:
raise Exception('unsupported scale')

width, height = source.size
image = Image.new('RGB', (width * scale, height * scale))
image.paste(source, (0, 0))

# TODO: only valid for overlap = 0.5
if overlap == 0.5:
tiles = [
(0, tile * -overlap),
(tile * overlap, tile * -overlap),
(tile * overlap, 0),
(tile * overlap, tile * overlap),
(0, tile * overlap),
(tile * -overlap, tile * -overlap),
(tile * -overlap, 0),
(tile * -overlap, tile * overlap),
]

# tile tuples is source, multiply by scale for dest
counter = 0
for left, top in tiles:
logger.info('processing tile %s of %s', counter, len(tiles))
counter += 1

# TODO: only valid for scale == 1, resize source for others
tile_image = image.crop((left, top, left + tile, top + tile))

for filter in filters:
tile_image = filter(tile_image, (left, top, tile))

image.paste(tile_image, (left * scale, top * scale))

return image
13 changes: 7 additions & 6 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
if val is None:
return SizeChart.auto

if type(val) is str:
if val in SizeChart:
return SizeChart[val]
else:
return int(val)

if type(val) is int:
return val

if type(val) is str:
for size in SizeChart:
if val == size.name:
return size

return int(val)

raise Exception('invalid size')


Expand Down
4 changes: 2 additions & 2 deletions common/pipelines/example.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"name": "save-local",
"type": "persist-disk",
"params": {
"tile_size": "8k"
"tile_size": "hd8k"
}
},
{
Expand All @@ -40,7 +40,7 @@
"bucket": "storage-stable-diffusion",
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
"profile_name": "ceph",
"tile_size": "8k"
"tile_size": "hd8k"
}
}
]
Expand Down

0 comments on commit a4d3f18

Please sign in to comment.