From 3ffbc00390d95a4a03096807b168e6b96fd2c14b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 12 Nov 2023 14:23:02 -0600 Subject: [PATCH] fix(api): turn alternatives back off for SDXL --- api/onnx_web/chain/blend_img2img.py | 9 +++++---- api/onnx_web/chain/source_txt2img.py | 9 +++++---- api/onnx_web/chain/upscale_outpaint.py | 9 +++++---- api/onnx_web/chain/upscale_stable_diffusion.py | 15 ++++++++------- 4 files changed, 23 insertions(+), 19 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 89f1301b9..274ab4071 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -81,10 +81,11 @@ def run( ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 3840fdd21..ce1f04dfe 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -130,10 +130,11 @@ def run( ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 85ddc0791..cdc3a0677 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -99,10 +99,11 @@ def run( ) else: # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) rng = np.random.RandomState(params.seed) result = pipe( diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 9d5a7b323..cf784b053 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -48,13 +48,14 @@ def run( ) generator = torch.manual_seed(params.seed) - prompt_embeds = encode_prompt( - pipeline, - prompt_pairs, - num_images_per_prompt=params.batch, - do_classifier_free_guidance=params.do_cfg(), - ) - pipeline.unet.set_prompts(prompt_embeds) + if not params.is_xl(): + prompt_embeds = encode_prompt( + pipeline, + prompt_pairs, + num_images_per_prompt=params.batch, + do_classifier_free_guidance=params.do_cfg(), + ) + pipeline.unet.set_prompts(prompt_embeds) outputs = [] for source in sources: