AnimateDiff prompt travel #9231

Merged · 22 commits · Aug 28, 2024
25 changes: 21 additions & 4 deletions src/diffusers/models/attention.py
@@ -972,15 +972,32 @@ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
         return frame_indices
 
     def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
-        if weighting_scheme == "pyramid":
+        if weighting_scheme == "flat":
+            weights = [1.0] * num_frames
+
+        elif weighting_scheme == "pyramid":
             if num_frames % 2 == 0:
                 # num_frames = 4 => [1, 2, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
+                mid = num_frames // 2
+                weights = list(range(1, mid + 1))
                 weights = weights + weights[::-1]
             else:
                 # num_frames = 5 => [1, 2, 3, 2, 1]
-                weights = list(range(1, num_frames // 2 + 1))
-                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+                mid = (num_frames + 1) // 2
+                weights = list(range(1, mid))
+                weights = weights + [mid] + weights[::-1]
+
+        elif weighting_scheme == "delayed_reverse_sawtooth":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [0.01, 2, 2, 1]
+                mid = num_frames // 2
+                weights = [0.01] * (mid - 1) + [mid]
+                weights = weights + list(range(mid, 0, -1))
+            else:
+                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
+                mid = (num_frames + 1) // 2
+                weights = [0.01] * (mid - 1)
+                weights = weights + list(range(mid, 0, -1))
         else:
             raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
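For intuition, here is how per-frame weights like these are typically applied: FreeNoise-style blocks run attention over overlapping frame windows and take a weighted average where the windows overlap. The sketch below is illustrative only, with stand-in window outputs rather than the block's real hidden states:

```python
import torch

# Blend overlapping per-window outputs into one sequence using per-frame
# weights, the way FreeNoise-style frame averaging works (illustrative).
num_frames, context_length, context_stride = 8, 4, 2
weights = torch.tensor([1.0, 2.0, 2.0, 1.0])  # "pyramid" weights for a window of 4

accumulated = torch.zeros(num_frames)
weight_sum = torch.zeros(num_frames)
for start in range(0, num_frames - context_length + 1, context_stride):
    # Stand-in for the transformer block's output on frames [start, start + 4).
    window_output = torch.full((context_length,), float(start + 1))
    accumulated[start : start + context_length] += weights * window_output
    weight_sum[start : start + context_length] += weights

blended = accumulated / weight_sum  # weighted average over overlapping windows
print(blended)
```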
1 change: 0 additions & 1 deletion src/diffusers/models/controlnet_sparsectrl.py
@@ -691,7 +691,6 @@ def forward(
 
         emb = self.time_embedding(t_emb, timestep_cond)
         emb = emb.repeat_interleave(sample_num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
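Context for this deletion: the pipeline diffs below add `prompt_embeds.repeat_interleave(...)` before the model is called, so expanding `encoder_hidden_states` again inside the model would multiply the batch twice. A quick shape illustration, assuming SD1.5-style CLIP embedding sizes:

```python
import torch

# repeat_interleave on dim 0 turns per-video text embeddings into per-frame ones:
# (batch, seq_len, dim) -> (batch * num_frames, seq_len, dim)
num_frames = 4
prompt_embeds = torch.randn(2, 77, 768)  # two videos' CLIP embeddings (assumed shape)
per_frame = prompt_embeds.repeat_interleave(num_frames, dim=0)
print(per_frame.shape)  # torch.Size([8, 77, 768])
# Repeating again inside the model would yield batch * num_frames**2 rows.
```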
3 changes: 1 addition & 2 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -116,7 +116,7 @@ def __init__(
 
         self.in_channels = in_channels
 
-        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         self.proj_in = nn.Linear(in_channels, inner_dim)
 
         # 3. Define transformers blocks
@@ -2178,7 +2178,6 @@ def forward(
 
         emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
 
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
66 changes: 45 additions & 21 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -432,7 +432,6 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -470,8 +469,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -557,11 +556,15 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Optional[Union[str, List[str]]] = None,
         num_frames: Optional[int] = 16,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -701,9 +704,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -716,22 +720,39 @@ def __call__(
         text_encoder_lora_scale = (
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+            prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
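Conceptually, the new `_encode_prompt_free_noise` branch encodes each keyframe prompt and gives every frame its own conditioning by interpolating between keyframes. The toy sketch below is an assumption-laden simplification, not the actual implementation (which lives in the FreeNoise utilities):

```python
import torch

# Toy version of prompt travel: map {frame_index: prompt_embedding} keyframes
# to one embedding per frame by linear interpolation between keyframes.
def interpolate_prompt_embeds(keyframes: dict, num_frames: int) -> torch.Tensor:
    frame_ids = sorted(keyframes)
    assert len(frame_ids) >= 2, "need at least two keyframes to interpolate"
    embeds = torch.empty(num_frames, *keyframes[frame_ids[0]].shape)
    for start, end in zip(frame_ids, frame_ids[1:]):
        for f in range(start, end + 1):
            alpha = (f - start) / (end - start)
            embeds[f] = torch.lerp(keyframes[start], keyframes[end], alpha)
    return embeds

keyframes = {0: torch.zeros(77, 768), 15: torch.ones(77, 768)}
print(interpolate_prompt_embeds(keyframes, num_frames=16).shape)  # (16, 77, 768)
```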
@@ -783,6 +804,9 @@ def __call__(
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
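End to end, the dict-valued `prompt` combined with FreeNoise is what this PR calls prompt travel. A hedged usage sketch follows; the checkpoint names and AnimateLCM settings are taken from typical AnimateDiff examples and are assumptions, not part of this diff:

```python
import torch
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Assumed checkpoints; any SD1.5 base + AnimateDiff motion adapter should work.
adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
pipe.load_lora_weights(
    "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora"
)
pipe.set_adapters(["lcm-lora"], [0.8])

# FreeNoise sliding-window attention; the prompt travels across frame keyframes.
pipe.enable_free_noise(context_length=16, context_stride=4)

prompt = {
    0: "a caterpillar on a leaf, macro photo, high quality",
    48: "a cocoon hanging from a branch, macro photo, high quality",
    96: "a butterfly spreading its wings, macro photo, high quality",
}
video = pipe(
    prompt=prompt,
    negative_prompt="bad quality, worse quality",
    num_frames=128,
    guidance_scale=2.0,
    num_inference_steps=6,
    generator=torch.Generator("cpu").manual_seed(0),
).frames[0]
export_to_gif(video, "prompt_travel.gif")
```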
@@ -505,8 +505,8 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt is not None and not isinstance(prompt, (str, list, dict)):
+            raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -699,6 +699,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     def __call__(
         self,
@@ -858,9 +862,10 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False
 
         # 2. Define call parameters
-        if prompt is not None and isinstance(prompt, str):
+        if prompt is not None and isinstance(prompt, (str, dict)):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
@@ -883,22 +888,39 @@ def __call__(
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-            prompt,
-            device,
-            num_videos_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
-            clip_skip=self.clip_skip,
-        )
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+        if self.free_noise_enabled:
+            prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
+                prompt=prompt,
+                num_frames=num_frames,
+                device=device,
+                num_videos_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+        else:
+            prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+                prompt,
+                device,
+                num_videos_per_prompt,
+                self.do_classifier_free_guidance,
+                negative_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                lora_scale=text_encoder_lora_scale,
+                clip_skip=self.clip_skip,
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            if self.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+            prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
 
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ def __call__(
         # 8. Denoising loop
         with self.progress_bar(total=self._num_timesteps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ def __call__(
             else:
                 control_model_input = latent_model_input
                 controlnet_prompt_embeds = prompt_embeds
-                controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
 
             if isinstance(controlnet_keep[i], list):
                 cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
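The new `interrupt` property mirrors the pattern in other diffusers pipelines: a step-end callback can flip the private flag so the denoising loop skips its remaining steps. A minimal sketch, assuming a `pipe` built as in the example above and the standard `callback_on_step_end` hook:

```python
# Stop denoising early by setting the private flag that `interrupt` reads.
def stop_after_ten_steps(pipe, step_index, timestep, callback_kwargs):
    if step_index >= 10:
        pipe._interrupt = True
    return callback_kwargs

video = pipe(
    prompt="a rocket launching at dawn, cinematic",
    num_frames=16,
    callback_on_step_end=stop_after_ten_steps,
).frames[0]
```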
@@ -1143,6 +1143,8 @@ def __call__(
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
         add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
@@ -878,6 +878,8 @@ def __call__(
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
+        prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
+
         # 4. Prepare IP-Adapter embeddings
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(