Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Sep 5, 2024
1 parent 661a0b3 commit c55a50a
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
58 changes: 58 additions & 0 deletions docs/source/en/api/pipelines/animatediff.md
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,64 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
</tr>
</table>

## Using FreeNoise

[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu.

FreeNoise is a sampling mechanism that allows the generation of longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper.

```python
import torch
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter
from diffusers.utils import export_to_video, load_image

# Load pipeline
dtype = torch.float16
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)

pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

pipe.load_lora_weights(
"wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora"
)
pipe.set_adapters(["lcm_lora"], [0.8])

# Enable FreeNoise for long prompt generation
pipe.enable_free_noise(context_length=16, context_stride=4)
pipe.to("cuda")

# Optionally, enable memory efficient inference
pipe.enable_free_noise_split_inference()
pipe.unet.enable_forward_chunking(16)

# Can be a single prompt, or a dictionary with frame timesteps
prompt = {
0: "A caterpillar on a leaf, high quality, photorealistic",
40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic",
80: "A cocoon on a leaf, flowers in the backgrond, photorealistic",
120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic",
160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic",
200: "A beautiful butterfly, flying away in a forest, photorealistic",
240: "A cyberpunk butterfly, neon lights, glowing",
}
negative_prompt = "bad quality, worst quality, jpeg artifacts"

# Run inference
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=256,
guidance_scale=2.5,
num_inference_steps=10,
generator=torch.Generator("cpu").manual_seed(0),
)

# Save video
frames = output.frames[0]
export_to_video(frames, "output.mp4", fps=16)
```

## Using `from_single_file` with the MotionAdapter

Expand Down
15 changes: 15 additions & 0 deletions src/diffusers/pipelines/free_noise_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,21 @@ def _enable_split_inference_samplers_(
samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])

def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
r"""
Enable FreeNoise memory optimizations by utilizing
[`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
Args:
spatial_split_size (`int`, defaults to `256`):
The split size across spatial dimensions for internal blocks. This is used in facilitating split
inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
modeling blocks.
temporal_split_size (`int`, defaults to `16`):
The split size across temporal dimensions for internal blocks. This is used in facilitating split
inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial
attention, resnets, downsampling and upsampling blocks.
"""
# TODO(aryan): Discuss on what's the best way to provide more control to users
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
if getattr(block, "motion_modules", None) is not None:
Expand Down

0 comments on commit c55a50a

Please sign in to comment.