AnimateDiff prompt travel #9231
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi! I tried to recreate the video with the girl and the seasons, but I got an error like this: here is my notebook
@aycax Thanks for testing! I recently made some updates to the chunked inference code design. For the current version of the PR, you will now have to do:

```python
context_length = 16
context_stride = 4

pipe.enable_free_noise(context_length=context_length, context_stride=context_stride)
pipe.enable_free_noise_chunked_inference()
pipe.unet.enable_forward_chunking(context_length)
```

Or, you could install the version of this PR before this commit. Since this PR is a work-in-progress, some things might change unexpectedly, but I'll be sure to update all the example code to reflect correct usage when it's ready.
src/diffusers/models/attention.py (outdated)

```diff
@@ -1087,8 +1104,15 @@ def forward(
     accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
     num_times_accumulated[:, frame_start:frame_end] += weights

-    hidden_states = torch.where(
-        num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+    hidden_states = torch.cat(
```
Better to include this change in the memory optimisations PR, no?
Sounds good! Will revert here
```diff
@@ -69,6 +70,9 @@ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, Dow
     motion_module.transformer_blocks[i].load_state_dict(
         basic_transfomer_block.state_dict(), strict=True
     )
+    motion_module.transformer_blocks[i].set_chunk_feed_forward(
```
Also probably better to include in the memory optimisations PR?
What does this PR do?
Adds support for prompt travel to AnimateDiff pipelines.
Examples
The following are some bare-minimum examples that demonstrate the expected usage of the new features. Note that for latent upscaling, we naively upscale the latents and don't make use of an upscaler model here (a more complex workflow the reader could explore). Combined with other pipelines and techniques, one can generate really cool animations.
Text-to-Video Prompt Travel
animatediff_multiprompt_2.webm
Code
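For readers without access to the collapsed snippet above, here is a minimal sketch of what text-to-video prompt travel could look like with FreeNoise enabled. It assumes `prompt` accepts a dict mapping frame indices to prompts; the model IDs, frame indices, and prompt texts are illustrative placeholders, not the PR's exact example.

```python
import torch

from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_video

# Illustrative model choices; any SD 1.5 checkpoint plus a motion adapter should work.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# FreeNoise lets us generate more frames than the motion module's context window
# and interpolate prompt embeddings between keyframes.
pipe.enable_free_noise(context_length=16, context_stride=4)

# Prompt travel: frame index -> prompt; embeddings are interpolated in between.
prompt = {
    0: "a girl standing in a field of flowers, spring, cherry blossoms",
    32: "a girl standing in a field of flowers, summer, lush green grass",
    64: "a girl standing in a field of flowers, autumn, falling orange leaves",
    96: "a girl standing in a field of flowers, winter, heavy snowfall",
}

frames = pipe(
    prompt=prompt,
    negative_prompt="bad quality, worst quality",
    num_frames=128,
    num_inference_steps=25,
    guidance_scale=7.5,
).frames[0]
export_to_video(frames, "animatediff_prompt_travel.mp4")
```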
Text-to-Video Prompt Travel + Latent Upscale
animatediff_multiprompt_1.webm
animatediff_multiprompt_1_latent_upscaled.webm
Code
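Since the collapsed snippet is not reproduced here, the following is a minimal sketch of the naive latent upscaling described above. It assumes the first pass returns latents via `output_type="latent"` with the usual AnimateDiff layout `(batch, channels, frames, height, width)`; in practice you would typically run a second, vid2vid-style denoising pass over the upscaled latents before decoding.

```python
import torch
import torch.nn.functional as F

# First pass: low-resolution latents instead of decoded frames.
latents = pipe(prompt=prompt, num_frames=128, output_type="latent").frames

batch, channels, num_frames, height, width = latents.shape

# Fold frames into the batch axis so we can interpolate spatially,
# then naively upscale 2x in latent space (no upscaler model involved).
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch * num_frames, channels, height, width)
latents = F.interpolate(latents, scale_factor=2, mode="bilinear", align_corners=False)

# Decode frame by frame to keep VAE memory bounded.
latents = latents / pipe.vae.config.scaling_factor
with torch.no_grad():
    decoded = [pipe.vae.decode(latents[i : i + 1]).sample for i in range(latents.shape[0])]
```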
Image-to-Video Prompt Travel
animatediff_ipadapter_multiimage.webm
Code
Image-to-Video Prompt Travel + Latent Upscale
animatediff_ipadapter_multiimage_latent_upscaled.webm
Code
Video-to-Video Prompt Travel + ControlNet
TODO: ControlNet has not been optimized for batched inference yet. This will be updated soon.
Code
Video-to-Video Prompt Travel + ControlNet + Latent Upscale
TODO: ControlNet has not been optimized for batched inference yet. This will be updated soon.
Code
Frame interpolation
TODO: SparseCtrl is not supported or optimized for batched inference yet. This will be updated soon.
Code
Memory optimizations
Nothing too fancy here. To lower memory usage, chunking is performed across the spatial batch in motion blocks; across the temporal batch in transformer blocks, resnet, upsampling, and downsampling blocks; and across the spatial/temporal batch in attention feed-forward chunking. This is mostly possible because the normalization layers are either LayerNorm or GroupNorm, which play well with chunked inference across batch dimensions. To enable the memory optimizations, the following are required (at the time of making the PR; this will be updated after reviews):
- `pipe.unet.enable_attn_chunking`: chunking across temporal batches when passing through spatial attention blocks
- `pipe.unet.enable_motion_module_chunking`: chunking across spatial batches when passing through temporal attention blocks
- `pipe.unet.enable_resnet_chunking`: chunking across resnet layers
- `pipe.unet.enable_forward_chunking`: chunking across attention FeedForward layers

The main pain points, as observed from the memory trace spikes, were the attention layers, resnet layers, upsampling/downsampling layers, and a call to `torch.where`. After these improvements, the memory spikes are flattened out, but there is still room for improvement with offloading and better batching across all intermediate layers that are not handled perfectly yet. The end goal is to run techniques like FreeNoise in a manner such that total memory depends only on the context length and scales linearly, or remains constant, with the number of frames.
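As a rough sketch only: of the four toggles above, only `enable_forward_chunking(context_length)` appears with its argument earlier in this thread, so the arguments passed to the other three are assumptions, and all of these are WIP APIs that may change before merge.

```python
context_length = 16

pipe.enable_free_noise(context_length=context_length, context_stride=4)

# Chunk sizes below are assumed; tune them for your GPU.
pipe.unet.enable_attn_chunking(context_length)           # temporal batch in spatial attention
pipe.unet.enable_motion_module_chunking(context_length)  # spatial batch in temporal attention
pipe.unet.enable_resnet_chunking(context_length)         # resnet layers
pipe.unet.enable_forward_chunking(context_length)        # attention feed-forward layers
```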
Adding my memory optimization logs for some of the things that worked, for anyone who's interested in the story of going from 25 GB for 128 frames to 12 GB by making inference-only changes to the UNet layers. Many other things I tried were too stupid to make the list and didn't work, so those have been omitted. It goes without saying that these can be combined with other memory optimization techniques to further reduce overall usage.
log
Using LoRAs
Since enabling FreeNoise replaces BasicTransformerBlock with FreeNoiseTransformerBlock, any LoRAs loaded into the attention QKV or output projection layers will not be usable. This is because the LoRA config information is not readily available, and supporting it would call for some tedious implementation work. The easiest way to avoid LoRA-related state dict loading failures is to enable FreeNoise BEFORE loading any LoRAs.
Additionally, you could load, fuse, and then unload the LoRAs. With this approach, the ordering of the `enable_free_noise` call would not matter.
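A minimal sketch of both orderings, using the standard diffusers LoRA API (the LoRA path is a placeholder):

```python
# Recommended: enable FreeNoise first, so the BasicTransformerBlock ->
# FreeNoiseTransformerBlock swap happens before any LoRA layers are attached.
pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.load_lora_weights("path/to/your-lora")  # placeholder checkpoint

# Order-independent alternative: load, fuse into the base weights, then unload.
# pipe.load_lora_weights("path/to/your-lora")
# pipe.fuse_lora()
# pipe.unload_lora_weights()
# pipe.enable_free_noise(context_length=16, context_stride=4)  # before or after
```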
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@DN6