CogVideoX-5b-I2V support (#9418)
* draft Init

* draft

* vae encode image

* make style

* image latents preparation

* remove image encoder from conversion script

* fix minor bugs

* make pipeline work

* make style

* remove debug prints

* fix imports

* update example

* make fix-copies

* add fast tests

* fix import

* update vae

* update docs

* update image link

* apply suggestions from review

* apply suggestions from review

* add slow test

* make use of learned positional embeddings

* apply suggestions from review

* doc change

* Update convert_cogvideox_to_diffusers.py

* make style

* final changes

* make style

* fix tests

---------

Co-authored-by: Aryan <[email protected]>
zRzRzRzRzRzRzR and a-r-r-o-w authored Sep 16, 2024
1 parent 2171f77 commit 8336405
Showing 12 changed files with 1,328 additions and 25 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/api/loaders/single_file.md
@@ -23,6 +23,8 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
## Supported pipelines

- [`CogVideoXPipeline`]
- [`CogVideoXImageToVideoPipeline`]
- [`CogVideoXVideoToVideoPipeline`]
- [`StableDiffusionPipeline`]
- [`StableDiffusionImg2ImgPipeline`]
- [`StableDiffusionInpaintPipeline`]
30 changes: 22 additions & 8 deletions docs/source/en/api/pipelines/cogvideox.md
@@ -29,9 +29,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.

There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.

## Inference

@@ -41,10 +44,15 @@ First, load the pipeline:

```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or "THUDM/CogVideoX-2b"
```

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
If you are using the image-to-video pipeline, load it as follows:

```python
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda")
```

Then change the memory layout of the pipeline's `transformer` component to `torch.channels_last`:
@@ -53,7 +61,7 @@
pipe.transformer.to(memory_format=torch.channels_last)
```

Compile the components and run inference:

```python
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wood..."
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```
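If you are running the image-to-video pipeline instead, the call additionally takes a conditioning image. Here is a minimal sketch; the image URL and prompt are illustrative placeholders rather than part of the original example:

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# bf16 is the recommended dtype for CogVideoX-5b-I2V
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

# Any RGB image at the model's default resolution can serve as conditioning
image = load_image("https://example.com/astronaut.jpg")  # illustrative URL
prompt = "An astronaut slowly waves at the camera while floating in space."

video = pipe(prompt=prompt, image=image, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```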

The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:

```
Without torch.compile(): Average inference time: 96.89 seconds.
@@ -98,6 +106,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- all
- __call__

## CogVideoXImageToVideoPipeline

[[autodoc]] CogVideoXImageToVideoPipeline
- all
- __call__

## CogVideoXVideoToVideoPipeline

[[autodoc]] CogVideoXVideoToVideoPipeline
36 changes: 29 additions & 7 deletions scripts/convert_cogvideox_to_diffusers.py
@@ -4,7 +4,13 @@
import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)


def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
@@ -78,6 +84,7 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
"mixins.final_layer.norm_final": "norm_out.norm",
"mixins.final_layer.linear": "proj_out",
"mixins.final_layer.adaLN_modulation.1": "norm_out.linear",
"mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding", # Specific to CogVideoX-5b-I2V
}
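For orientation, the rename map above (including the new `mixins.pos_embed.pos_embedding` entry for CogVideoX-5b-I2V) is applied by substring replacement over the original state-dict keys. A hedged sketch of that idea, using an illustrative helper name rather than the script's actual function:

```python
def rename_key(key: str, rename_dict: dict) -> str:
    # Replace every mapped substring that occurs in the original key.
    for old, new in rename_dict.items():
        if old in key:
            key = key.replace(old, new)
    return key

# Excerpt of the map above: the I2V-specific positional embedding lands on the patch embedding module.
rename_dict = {"mixins.pos_embed.pos_embedding": "patch_embed.pos_embedding"}
print(rename_key("mixins.pos_embed.pos_embedding", rename_dict))  # -> "patch_embed.pos_embedding"
```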

TRANSFORMER_SPECIAL_KEYS_REMAP = {
@@ -131,15 +138,18 @@ def convert_transformer(
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."

original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel(
in_channels=32 if i2v else 16,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
use_learned_positional_embeddings=i2v,
).to(dtype=dtype)

for key in list(original_state_dict.keys()):
@@ -153,7 +163,6 @@ def convert_transformer(
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True)
return transformer

@@ -205,6 +214,7 @@ def get_args():
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
return parser.parse_args()


@@ -225,6 +235,7 @@ def get_args():
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
args.i2v,
dtype,
)
if args.vae_ckpt_path is not None:
@@ -234,7 +245,7 @@ def get_args():
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()

@@ -252,9 +263,17 @@ def get_args():
"timestep_spacing": "trailing",
}
)

if args.i2v:
pipeline_cls = CogVideoXImageToVideoPipeline
else:
pipeline_cls = CogVideoXPipeline

pipe = pipeline_cls(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)

if args.fp16:
@@ -265,4 +284,7 @@ def get_args():
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
# is either fp16/bf16 here).

# This is necessary for users with insufficient memory, such as those using Colab or notebooks,
# as sharding reduces the memory needed when loading the model.
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
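After conversion, the saved (and now sharded) pipeline can be reloaded like any other; `from_pretrained` resolves the shards transparently. A hedged sketch, where the output directory name is a placeholder for whatever was passed to `--output_path`:

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline

# "converted-cogvideox-5b-i2v" stands in for the --output_path used above
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "converted-cogvideox-5b-i2v", torch_dtype=torch.bfloat16
)

# On memory-constrained machines (e.g. Colab), offloading keeps only the active
# component on the GPU instead of the whole pipeline.
pipe.enable_model_cpu_offload()
```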
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -255,6 +255,7 @@
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CycleDiffusionPipeline",
@@ -703,6 +704,7 @@
AudioLDMPipeline,
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CycleDiffusionPipeline,
14 changes: 10 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -1089,8 +1089,10 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
return self.tiled_encode(x)

frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
enc = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1140,8 +1142,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
return self.tiled_decode(z, return_dict=return_dict)

frame_batch_size = self.num_latent_frames_batch_size
num_batches = num_frames // frame_batch_size
dec = []
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1233,8 +1236,10 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
@@ -1309,8 +1314,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
num_batches = num_frames // frame_batch_size
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
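All four hunks above share the same frame-batching arithmetic. A standalone sketch of how the loop slices frames, assuming (as the note says) that `num_frames` is `1`, `k * frame_batch_size`, or `k * frame_batch_size + 1`:

```python
def frame_batches(num_frames: int, frame_batch_size: int):
    # Mirrors the slicing above: the first batch absorbs the single leftover frame, if any.
    num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
    remaining_frames = num_frames % frame_batch_size
    for i in range(num_batches):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        yield start_frame, end_frame

# 49 frames with a batch size of 8 -> [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
print(list(frame_batches(49, 8)))
```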
16 changes: 13 additions & 3 deletions src/diffusers/models/embeddings.py
@@ -350,6 +350,7 @@ def __init__(
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_positional_embeddings: bool = True,
use_learned_positional_embeddings: bool = True,
) -> None:
super().__init__()

@@ -363,15 +364,17 @@ def __init__(
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings

self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)

if use_positional_embeddings or use_learned_positional_embeddings:
persistent = use_learned_positional_embeddings
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
@@ -415,8 +418,15 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]

if self.use_positional_embeddings or self.use_learned_positional_embeddings:
if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
raise ValueError(
"It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
"If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)

pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

if (
self.sample_height != height
or self.sample_width != width
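The `persistent` flag introduced in the patch-embedding change above is what lets the I2V model's learned positional embedding be saved with (and loaded from) the checkpoint, while the sinusoidal variant is still recomputed on the fly. A minimal, diffusers-independent illustration:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, persistent: bool):
        super().__init__()
        # Non-persistent buffers are excluded from state_dict and rebuilt at load time;
        # persistent buffers are serialized together with the model weights.
        self.register_buffer("pos_embedding", torch.zeros(4), persistent=persistent)

print("pos_embedding" in Demo(persistent=False).state_dict())  # False
print("pos_embedding" in Demo(persistent=True).state_dict())   # True
```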
14 changes: 13 additions & 1 deletion src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -235,10 +235,18 @@ def __init__(
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_rotary_positional_embeddings: bool = False,
use_learned_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim

if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
raise ValueError(
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
"issue at https://github.com/huggingface/diffusers/issues."
)

# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
@@ -254,6 +262,7 @@ def __init__(
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)

@@ -465,8 +474,11 @@ def custom_forward(*inputs):
hidden_states = self.proj_out(hidden_states)

# 5. Unpatchify
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels equals the number of output channels)
# - However, CogVideoX-5b-I2V also takes concatenated image latents as input (number of input channels is twice the number of output channels)
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

if not return_dict:
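The unpatchify change can be checked with a small standalone sketch (shapes are illustrative): `-1` lets the reshape infer the output channel count, which stays correct even when the transformer's input channels differ from its output channels, as in the I2V variant.

```python
import torch

# Illustrative sizes: one token per (frame, patch); each token holds out_channels * p * p values.
batch_size, num_frames, height, width, p, out_channels = 1, 2, 16, 16, 2, 4
hidden_states = torch.randn(
    batch_size, num_frames * (height // p) * (width // p), out_channels * p * p
)

# Same unpatchify as in the forward pass above.
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(output.shape)  # torch.Size([1, 2, 4, 16, 16])
```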
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -138,7 +138,11 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -461,7 +465,7 @@
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/cogvideo/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
_import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline

else: