Question about classifier guidance for image in training code #54

Open
tykim0507 opened this issue May 13, 2024 · 3 comments

Comments

@tykim0507

Hello, nice work on the training code, and thank you for sharing it!
I have a question about the image-conditioned classifier-free guidance in your code.

if args.conditioning_dropout_prob is not None:
  random_p = torch.rand(
      bsz, device=latents.device, generator=generator)
  # Sample masks for the edit prompts.
  prompt_mask = random_p < 2 * args.conditioning_dropout_prob
  prompt_mask = prompt_mask.reshape(bsz, 1, 1)
  # Final text conditioning.
  null_conditioning = torch.zeros_like(encoder_hidden_states)
  encoder_hidden_states = torch.where(
      prompt_mask, null_conditioning.unsqueeze(1), encoder_hidden_states.unsqueeze(1))
  # Sample masks for the original images.
  image_mask_dtype = conditional_latents.dtype
  image_mask = 1 - (
      (random_p >= args.conditioning_dropout_prob).to(
          image_mask_dtype)
      * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
  )
  image_mask = image_mask.reshape(bsz, 1, 1, 1)
  # Final image conditioning.
  conditional_latents = image_mask * conditional_latents

I wonder whether this is the official way of implementing classifier-free guidance for image conditions. With the default dropout probability of 0.1, the cases are:

with prob 0.1: the concatenated first frame is kept, the first frame for cross-attention is zeroed
with prob 0.1: the concatenated first frame is zeroed, the first frame for cross-attention is zeroed
with prob 0.1: the concatenated first frame is zeroed, the first frame for cross-attention is kept
with prob 0.7: the concatenated first frame is kept, the first frame for cross-attention is kept

Is this what you intended?
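
The mask logic in the snippet can be checked without running the training script. The following is a minimal sketch in plain Python that mirrors the two comparisons on `random_p` for a single sample and sweeps an even grid over [0, 1) to estimate how often each case occurs. The helper names `dropout_case` and `case_probabilities` are hypothetical and not part of the repository.

```python
def dropout_case(random_p, drop_prob):
    """Mirror the mask logic of the quoted snippet for one sample.

    Returns (drop_concat, drop_cross_attn), where True means the
    corresponding condition is replaced by zeros.
    """
    # prompt_mask = random_p < 2 * drop_prob
    drop_cross_attn = random_p < 2 * drop_prob
    # image_mask = 1 - [random_p >= drop_prob] * [random_p < 3 * drop_prob]
    drop_concat = drop_prob <= random_p < 3 * drop_prob
    return drop_concat, drop_cross_attn


def case_probabilities(drop_prob, steps=10_000):
    """Estimate the probability of each case over a uniform random_p."""
    counts = {}
    for i in range(steps):
        random_p = (i + 0.5) / steps  # midpoints of an even grid over [0, 1)
        key = dropout_case(random_p, drop_prob)
        counts[key] = counts.get(key, 0) + 1
    return {key: n / steps for key, n in counts.items()}


probs = case_probabilities(0.1)
for (drop_concat, drop_cross_attn), prob in sorted(probs.items()):
    print(f"drop_concat={drop_concat}, drop_cross_attn={drop_cross_attn}: {prob}")
```

For `drop_prob = 0.1`, each of the three dropout cases covers one tenth of the interval, and the remaining seven tenths keep both conditions.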

Thank you

@xiangweifeng

@tykim0507
Author

Yeah, that's true, and I am curious about how to enable CFG.
A 20% dropout probability seems reasonable, and it also seems reasonable that the image prompt and the concatenated image are not always dropped simultaneously.
I'm just curious whether this is the official way of enabling image-prompt CFG!

@xiangweifeng

Refer to Section 3.2.3 of "Hierarchical Masked 3D Diffusion Model for Video Outpainting". I think this training code adopts two CFG conditions, so the corresponding changes should be made in the inference stage.
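
As a rough illustration of what such an inference-side change could look like, here is a minimal sketch that combines three noise predictions with two guidance scales, in the style of the two-scale formulation used by InstructPix2Pix. This is an assumption for illustration only: it is not taken from this repository or from the cited paper, and the function name `dual_cfg` and its parameter names are hypothetical. The inputs can be floats, NumPy arrays, or torch tensors, since only elementwise arithmetic is used.

```python
def dual_cfg(eps_uncond, eps_concat_only, eps_full,
             concat_scale, cross_attn_scale):
    """Combine three noise predictions using two guidance scales.

    eps_uncond:      prediction with both image conditions zeroed
    eps_concat_only: prediction with the concatenated latents kept and
                     the cross-attention embedding zeroed
    eps_full:        prediction with both conditions kept
    """
    return (
        eps_uncond
        + concat_scale * (eps_concat_only - eps_uncond)
        + cross_attn_scale * (eps_full - eps_concat_only)
    )


# With both scales set to 1.0 the result reduces to the fully
# conditioned prediction, i.e. no extra guidance is applied.
no_guidance = dual_cfg(0.0, 1.0, 3.0, concat_scale=1.0, cross_attn_scale=1.0)
guided = dual_cfg(0.0, 1.0, 3.0, concat_scale=2.0, cross_attn_scale=1.5)
```

Each denoising step would need three forward passes (or one batched pass of size three) to obtain the three predictions, which is only meaningful because the training dropout exposes the model to each of these condition combinations.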
