Skip to content

Commit

Permalink
fix(api): complete panorama tiles for SD pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Dec 3, 2023
1 parent 103d1a4 commit b54a57b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
26 changes: 21 additions & 5 deletions api/onnx_web/diffusers/pipelines/panorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer

from onnx_web.chain.tile import make_tile_mask

from ..utils import LATENT_CHANNELS, LATENT_FACTOR, parse_regions, repair_nan
from ...chain.tile import make_tile_mask
from ...params import Size
from ..utils import (
LATENT_CHANNELS,
LATENT_FACTOR,
expand_latents,
parse_regions,
repair_nan,
)

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -373,7 +379,7 @@ def get_views(self, panorama_height, panorama_width, window_size, stride):
w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end))

return views
return (views, (h_end, w_end))

@torch.no_grad()
def text2img(
Expand Down Expand Up @@ -552,10 +558,17 @@ def text2img(
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

# panorama additions
views = self.get_views(height, width, self.window, self.stride)
views, resize = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)

latents = expand_latents(
latents,
generator.randint(),
Size(width, height),
sigma=self.scheduler.init_noise_sigma,
)

for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
last = i == (len(self.scheduler.timesteps) - 1)
count.fill(0)
Expand Down Expand Up @@ -707,6 +720,9 @@ def text2img(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

# remove extra margins
latents = latents[:, :, 0:height, 0:width]

latents = np.clip(latents, -4, +4)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
Expand Down
7 changes: 6 additions & 1 deletion api/onnx_web/diffusers/pipelines/panorama_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,12 @@ def text2img(
value = np.zeros_like((latents[0], latents[1], *resize))

# adjust latents
latents = expand_latents(latents, generator.randint(), Size(width, height))
latents = expand_latents(
latents,
generator.randint(),
Size(width, height),
sigma=self.scheduler.init_noise_sigma,
)

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
Expand Down
3 changes: 2 additions & 1 deletion api/onnx_web/diffusers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,12 @@ def expand_latents(
latents: np.ndarray,
seed: int,
size: Size,
sigma: float = 1.0,
) -> np.ndarray:
batch, _channels, height, width = latents.shape
extra_latents = get_latents_from_seed(seed, size, batch=batch)
extra_latents[:, :, 0:height, 0:width] = latents
return extra_latents
return extra_latents * np.float64(sigma)


def get_tile_latents(
Expand Down

0 comments on commit b54a57b

Please sign in to comment.