Skip to content

Commit

Permalink
fix(api): generate new latents for partial tiles
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jul 13, 2023
1 parent 4b8358b commit 3d4c77d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/chain/source_txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def run(
if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
else:
latents = get_tile_latents(latents, dims, latent_size)
latents = get_tile_latents(latents, params.seed, latent_size, dims)

pipe_type = params.get_valid_pipeline("txt2img")
pipe = load_pipeline(
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_outpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def run(
if latents is None:
latents = get_latents_from_seed(params.seed, latent_size, params.batch)
else:
latents = get_tile_latents(latents, dims, latent_size)
latents = get_tile_latents(latents, params.seed, latent_size, dims)

if params.lpw():
logger.debug("using LPW pipeline for inpaint")
Expand Down
12 changes: 5 additions & 7 deletions api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:

def get_tile_latents(
full_latents: np.ndarray,
dims: Tuple[int, int, int],
seed: int,
size: Size,
dims: Tuple[int, int, int],
) -> np.ndarray:
x, y, tile = dims
t = tile // LATENT_FACTOR
Expand All @@ -284,12 +285,9 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt]

if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
px = mx - tile_latents.shape[3]
py = my - tile_latents.shape[2]

tile_latents = np.pad(
tile_latents, ((0, 0), (0, 0), (0, py), (0, px)), mode="reflect"
)
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
extra_latents[:, :, 0:t, 0:t] = tile_latents
tile_latents = extra_latents

return tile_latents

Expand Down

0 comments on commit 3d4c77d

Please sign in to comment.