Skip to content

Commit

Permalink
fix(api): make panorama work with prompt alternatives
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Apr 27, 2023
1 parent b8b73d8 commit c6fc860
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
19 changes: 16 additions & 3 deletions api/onnx_web/diffusers/pipelines/panorama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
from logging import getLogger
from typing import Callable, List, Optional, Union

import PIL
Expand Down Expand Up @@ -299,7 +300,7 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
views.append((h_start, h_end, w_start, w_end))
return views

def __call__(
def text2img(
self,
prompt: Union[str, List[str]] = None,
height: Optional[int] = 512,
Expand Down Expand Up @@ -635,11 +636,11 @@ def img2img(
# prep image
image = preprocess(image).cpu().numpy()
image = image.astype(latents_dtype)

# encode the init image into latents and scale the latents
latents = self.vae_encoder(sample=image)[0]
latents = 0.18215 * latents

latents = latents * np.float64(self.scheduler.init_noise_sigma)
# latents = latents * np.float64(self.scheduler.init_noise_sigma)

# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
Expand Down Expand Up @@ -746,3 +747,15 @@ def img2img(
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

def __call__(
self,
*args,
**kwargs,
):
if len(args) > 0 and (isinstance(args[0], np.ndarray) or isinstance(args[0], PIL.Image.Image)):
logger.debug("running img2img panorama pipeline")
return self.img2img(*args, **kwargs)
else:
logger.debug("running txt2img panorama pipeline")
return self.text2img(*args, **kwargs)
8 changes: 4 additions & 4 deletions api/onnx_web/diffusers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def run_loopback(
)

def loopback_iteration(source: Image.Image):
if pipe_type in ["lpw", "panorama"]:
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
source,
Expand Down Expand Up @@ -174,7 +174,7 @@ def highres_tile(tile: Image.Image, dims):
callback=highres_progress,
)

if pipe_type in ["lpw", "panorama"]:
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = highres_pipe.img2img(
tile,
Expand Down Expand Up @@ -250,7 +250,7 @@ def run_txt2img_pipeline(
)
progress = job.get_progress_callback()

if pipe_type in ["lpw", "panorama"]:
if pipe_type == "lpw":
rng = torch.manual_seed(params.seed)
result = pipe.text2img(
params.prompt,
Expand Down Expand Up @@ -369,7 +369,7 @@ def run_img2img_pipeline(
pipe_params["image_guidance_scale"] = strength

progress = job.get_progress_callback()
if pipe_type in ["lpw", "panorama"]:
if pipe_type == "lpw":
logger.debug("using LPW pipeline for img2img")
rng = torch.manual_seed(params.seed)
result = pipe.img2img(
Expand Down

0 comments on commit c6fc860

Please sign in to comment.