feat(api): add support for extremely long prompts
Showing 3 changed files with 125 additions and 3 deletions.
The previous two-line stub:

def expand_prompt(prompt: str) -> str:
    return prompt

is replaced with a full implementation:
from logging import getLogger
from math import ceil
from typing import List, Optional, Union

import numpy as np
from diffusers import OnnxStableDiffusionPipeline

logger = getLogger(__name__)

# CLIP's context window: 75 prompt tokens plus the BOS and EOS tokens
MAX_TOKENS_PER_GROUP = 77
def expand_prompt(
    self: OnnxStableDiffusionPipeline,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]] = None,
) -> np.ndarray:
    # self provides:
    #   tokenizer: CLIPTokenizer
    #   text_encoder: OnnxRuntimeModel

    batch_size = len(prompt) if isinstance(prompt, list) else 1
    # tokenize the full prompt without truncation, then split it into
    # groups of up to MAX_TOKENS_PER_GROUP (77) tokens
    tokens = self.tokenizer(
        prompt,
        padding="max_length",
        return_tensors="np",
        max_length=self.tokenizer.model_max_length,
        truncation=False,
    )

    groups_count = ceil(tokens.input_ids.shape[1] / MAX_TOKENS_PER_GROUP)
    logger.info("splitting %s into %s groups", tokens.input_ids.shape, groups_count)
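    # for example, a 180-token prompt becomes ceil(180 / 77) == 3 groups,
    # holding 77, 77, and 26 tokens respectively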
    # note: np.array_split(tokens.input_ids, groups_count, axis=1) would split
    # into near-equal parts rather than fixed 77-token slices with a remainder
    groups = []
    for i in range(groups_count):
        group_start = i * MAX_TOKENS_PER_GROUP
        group_end = min(group_start + MAX_TOKENS_PER_GROUP, tokens.input_ids.shape[1])
        logger.info("building group for token slice [%s : %s]", group_start, group_end)
        groups.append(tokens.input_ids[:, group_start:group_end])
    # encode each chunk separately, since the CLIP encoder accepts at most 77 tokens
    logger.info("group token shapes: %s", [t.shape for t in groups])
    group_embeds = []
    for group in groups:
        logger.info("encoding group: %s", group.shape)
        embeds = self.text_encoder(input_ids=group.astype(np.int32))[0]
        group_embeds.append(embeds)

    # concatenate the group embeddings back into a single sequence
    logger.info("group embeds shape: %s", [t.shape for t in group_embeds])
    prompt_embeds = np.concatenate(group_embeds, axis=1)
    prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
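    # the result is (batch, total_tokens, hidden); with one image per prompt and
    # SD v1's 768-dim CLIP encoder, two groups yield e.g. (1, 154, 768)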
    # get unconditional embeddings for classifier-free guidance
    if do_classifier_free_guidance:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt`"
                " matches the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt
        # the negative prompt is tokenized with truncation, so it is at most
        # model_max_length tokens long
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        negative_prompt_embeds = self.text_encoder(
            input_ids=uncond_input.input_ids.astype(np.int32)
        )[0]

        # zero-pad the negative embeddings to match the expanded prompt length
        negative_padding = tokens.input_ids.shape[1] - negative_prompt_embeds.shape[1]
        logger.info(
            "padding negative prompt to match input: %s, %s, %s extra tokens",
            tokens.input_ids.shape,
            negative_prompt_embeds.shape,
            negative_padding,
        )
        negative_prompt_embeds = np.pad(
            negative_prompt_embeds,
            [(0, 0), (0, negative_padding), (0, 0)],
            mode="constant",
            constant_values=0,
        )
        negative_prompt_embeds = np.repeat(
            negative_prompt_embeds, num_images_per_prompt, axis=0
        )
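        # padding with zeros is a design choice: an alternative would be to tile
        # the empty-prompt embedding once per extra group, so the unconditional
        # branch sees valid CLIP output at every token position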
        # for classifier-free guidance we would otherwise need two forward passes;
        # concatenating the unconditional and text embeddings into a single batch
        # avoids running the model twice
        prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds])

    logger.info("expanded prompt shape: %s", prompt_embeds.shape)
    return prompt_embeds
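Since expand_prompt takes the pipeline as self and mirrors the signature of the pipeline's prompt-encoding helper, it can be patched in as a method. A minimal sketch follows, assuming diffusers' OnnxStableDiffusionPipeline._encode_prompt is the method being replaced; the model path is a placeholder, and whether the exported UNet accepts hidden states longer than 77 tokens depends on its dynamic axes:

from diffusers import OnnxStableDiffusionPipeline

# patch the long-prompt encoder in place of the stock 77-token one
OnnxStableDiffusionPipeline._encode_prompt = expand_prompt

# model path and provider are placeholders, not part of this commit
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./models/stable-diffusion-onnx",
    provider="CPUExecutionProvider",
)
image = pipe("a detailed prompt well beyond the usual 77-token limit ...").images[0]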