Stable Diffusion and SDXL Callbacks are fundamentally broken for prompt_embeds #9906

Open
AI-Casanova opened this issue Nov 11, 2024 · 5 comments
Labels
bug Something isn't working

Comments

@AI-Casanova

Describe the bug

I'm working on implementing scheduled prompting in SDNext and realized that the callback in SDXL pipelines is fundamentally not functional. This also applies to Stable Diffusion Pipelines and likely others.

prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop(
    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)

Of the variables in the preceding lines, the following are completely useless: they are not referenced inside the denoising loop, so adjusting them in the callback does nothing:

  • negative_prompt_embeds
  • negative_pooled_prompt_embeds
  • negative_add_time_ids

If self.do_classifier_free_guidance is True, the other variables are concatenations of their negative counterparts and themselves, not what their names state.

For the callback to operate intuitively, the code following # 7 should be split out into a separate function, and the following variables should NOT be overwritten:

  • prompt_embeds
  • add_text_embeds (naming convention should be fixed)
  • add_time_ids

That function should then be called post-callback to update all of the associated concat variables.
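A minimal sketch of what that split-out function could look like. The name `rebuild_cfg_inputs` and its `cat` parameter are hypothetical and not part of diffusers. Plain lists stand in for tensors so the sketch runs without torch; in a pipeline, `cat` would be something like `lambda parts: torch.cat(parts, dim=0)`.

```python
def rebuild_cfg_inputs(positive, negative, use_cfg, cat):
    """Build the inputs the denoising loop consumes (hypothetical helper).

    positive / negative: dicts keyed by "prompt_embeds", "add_text_embeds"
    and "add_time_ids", holding the un-concatenated values.
    use_cfg: whether classifier-free guidance is active.
    cat: function that concatenates a list of values along the batch axis.
    """
    if not use_cfg:
        # Without CFG the negatives are ignored entirely.
        return dict(positive)
    # With CFG the negative half comes first, then the positive half.
    return {key: cat([negative[key], positive[key]]) for key in positive}


def list_cat(parts):
    """Stand-in for torch.cat(parts, dim=0), operating on plain lists."""
    merged = []
    for part in parts:
        merged.extend(part)
    return merged
```

Called once before the loop and again after each callback, a helper like this would let the callback edit the plain positive and negative values while the loop keeps receiving correctly concatenated inputs.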

Reproduction

N/A

This is a code only issue.

Logs

No response

System Info

N/A

Who can help?

@yiyixuxu @sayakpaul

@AI-Casanova AI-Casanova added the bug Something isn't working label Nov 11, 2024
@AI-Casanova
Author

AI-Casanova commented Nov 11, 2024

I can and will create a monkeypatch in my code for the time being, but that is a lot of code duplication for something that should be handled inside the pipelines.

(OK the monkeypatch isn't that bad, but the problem remains)

@yiyixuxu
Collaborator

can you provide the callback example that you're trying to implement?
it is true that the negative* embeddings are not used inside the denoising loop, but they will be passed to your callback function and you will be able to use them inside it

@vladmandic
Contributor

the point of the callback is to be two way so values can be changed. if some values are two way and some are one way only, it's bad.

what we're trying to do is to alter embeds per step.

@AI-Casanova
Author

The problem comes when attempting to feed prompt embeddings back into the denoising loop. Here's an attempt, as non-optimized code:

def callback_fn(pipe, i, t, callback_kwargs):
    if i > 10:  # arbitrary set point
        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt("Prompt for last half of denoising")
        callback_kwargs["prompt_embeds"] = prompt_embeds
        callback_kwargs["negative_prompt_embeds"] = negative_prompt_embeds
        callback_kwargs["add_text_embeds"] = pooled_prompt_embeds
        callback_kwargs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
    return callback_kwargs

This works if CFG is disabled, because the negatives are ignored.

If CFG is enabled:

        callback_kwargs["prompt_embeds"] = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        callback_kwargs["add_text_embeds"] = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

works, but again fails if CFG is disabled, so the callback is required to determine whether CFG is enabled.
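As a workaround sketch, a callback could check the CFG flag itself and concatenate only when needed. This assumes the pipeline exposes `do_classifier_free_guidance` (the issue refers to it as `self.do_classifier_free_guidance`) and that `encode_prompt` returns the four values in the order used above. `make_prompt_switch_callback` and its `cat` argument are hypothetical names; with torch, `cat` would be `lambda parts: torch.cat(parts, dim=0)`.

```python
def make_prompt_switch_callback(new_prompt, switch_step, cat):
    """Return a step-end callback that swaps in `new_prompt` after `switch_step`.

    cat: function that concatenates a list of values along the batch axis.
    """

    def callback_fn(pipe, i, t, callback_kwargs):
        if i <= switch_step:
            # Before the switch point, leave everything untouched.
            return callback_kwargs

        # Non-optimized: re-encodes on every step past the switch point.
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(new_prompt)

        if pipe.do_classifier_free_guidance:
            # The loop expects negative and positive halves stacked together.
            prompt_embeds = cat([negative_prompt_embeds, prompt_embeds])
            pooled_prompt_embeds = cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds]
            )

        callback_kwargs["prompt_embeds"] = prompt_embeds
        callback_kwargs["add_text_embeds"] = pooled_prompt_embeds
        return callback_kwargs

    return callback_fn
```

The result would be passed as `callback_on_step_end`, with `prompt_embeds` and `add_text_embeds` listed in `callback_on_step_end_tensor_inputs`. It still leaves the callback responsible for knowing about CFG, which is the design problem this issue describes.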

@AI-Casanova
Author

for reference
image

Switching between "a woman sitting on a rotten log" and "a woman sitting on a park bench".
20 steps total, switching at the indicated step.


3 participants