diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index 18fe54638..500559dfc 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+from logging import getLogger
 from typing import Callable, List, Optional, Union
 
 import PIL
@@ -299,7 +300,7 @@ def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
             views.append((h_start, h_end, w_start, w_end))
         return views
 
-    def __call__(
+    def text2img(
         self,
         prompt: Union[str, List[str]] = None,
         height: Optional[int] = 512,
@@ -635,11 +636,11 @@ def img2img(
         # prep image
         image = preprocess(image).cpu().numpy()
         image = image.astype(latents_dtype)
 
+        # encode the init image into latents and scale the latents
         latents = self.vae_encoder(sample=image)[0]
         latents = 0.18215 * latents
-
-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        # latents = latents * np.float64(self.scheduler.init_noise_sigma)
 
         # get the original timestep using init_timestep
         offset = self.scheduler.config.get("steps_offset", 0)
@@ -746,3 +747,15 @@ def img2img(
             return (image, has_nsfw_concept)
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    def __call__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        if len(args) > 0 and (isinstance(args[0], np.ndarray) or isinstance(args[0], PIL.Image.Image)):
+            logger.debug("running img2img panorama pipeline")
+            return self.img2img(*args, **kwargs)
+        else:
+            logger.debug("running txt2img panorama pipeline")
+            return self.text2img(*args, **kwargs)
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 3b070a0b5..72fc85ee3 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -61,7 +61,7 @@ def run_loopback(
     )
 
     def loopback_iteration(source: Image.Image):
-        if pipe_type in ["lpw", "panorama"]:
+        if pipe_type == "lpw":
             rng = torch.manual_seed(params.seed)
             result = pipe.img2img(
                 source,
@@ -174,7 +174,7 @@ def highres_tile(tile: Image.Image, dims):
                 callback=highres_progress,
             )
 
-        if pipe_type in ["lpw", "panorama"]:
+        if pipe_type == "lpw":
             rng = torch.manual_seed(params.seed)
             result = highres_pipe.img2img(
                 tile,
@@ -250,7 +250,7 @@ def run_txt2img_pipeline(
     )
 
     progress = job.get_progress_callback()
-    if pipe_type in ["lpw", "panorama"]:
+    if pipe_type == "lpw":
         rng = torch.manual_seed(params.seed)
         result = pipe.text2img(
             params.prompt,
@@ -369,7 +369,7 @@ def run_img2img_pipeline(
         pipe_params["image_guidance_scale"] = strength
 
     progress = job.get_progress_callback()
-    if pipe_type in ["lpw", "panorama"]:
+    if pipe_type == "lpw":
         logger.debug("using LPW pipeline for img2img")
         rng = torch.manual_seed(params.seed)
         result = pipe.img2img(
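Review note: the new `__call__` on the panorama pipeline routes on the type of the first positional argument, which is why the run.py callers can drop "panorama" from their `pipe_type` checks and simply call the pipeline. A minimal, self-contained sketch of that dispatch follows; `StubPanoramaPipeline` and its return strings are stand-ins written for illustration only, not the real ONNX pipeline class.

# sketch of the type-based dispatch added in panorama.py __call__
import numpy as np
import PIL.Image


class StubPanoramaPipeline:
    def text2img(self, *args, **kwargs):
        # stand-in for the renamed text2img path
        return "text2img"

    def img2img(self, *args, **kwargs):
        # stand-in for the img2img path
        return "img2img"

    def __call__(self, *args, **kwargs):
        # route to img2img when the first positional arg is an image or array,
        # otherwise treat the call as a text prompt
        if len(args) > 0 and isinstance(args[0], (np.ndarray, PIL.Image.Image)):
            return self.img2img(*args, **kwargs)
        return self.text2img(*args, **kwargs)


pipe = StubPanoramaPipeline()
assert pipe("a wide mountain panorama") == "text2img"
assert pipe(PIL.Image.new("RGB", (64, 64)), prompt="add detail") == "img2img"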