CogVideoX-5b-I2V support #9418

Merged
merged 34 commits into main from cogvideox-5b-i2v on Sep 16, 2024
Changes from 22 commits

Commits (34)
- 6e3ae04 draft Init (zRzRzRzRzRzRzR, Sep 10, 2024)
- ad78738 draft (zRzRzRzRzRzRzR, Sep 11, 2024)
- 8966671 vae encode image (zRzRzRzRzRzRzR, Sep 11, 2024)
- a56c510 Merge branch 'huggingface:main' into cogvideox-5b-i2v (zRzRzRzRzRzRzR, Sep 12, 2024)
- c238fe2 make style (a-r-r-o-w, Sep 12, 2024)
- c1f7a80 image latents preparation (a-r-r-o-w, Sep 12, 2024)
- 3df95b2 remove image encoder from conversion script (a-r-r-o-w, Sep 12, 2024)
- 677a553 fix minor bugs (a-r-r-o-w, Sep 12, 2024)
- 4f51829 make pipeline work (a-r-r-o-w, Sep 12, 2024)
- 33c7cd6 make style (a-r-r-o-w, Sep 12, 2024)
- bc07f9f remove debug prints (a-r-r-o-w, Sep 12, 2024)
- 98f1023 fix imports (a-r-r-o-w, Sep 12, 2024)
- aa12e1b update example (a-r-r-o-w, Sep 12, 2024)
- 1970f4f make fix-copies (a-r-r-o-w, Sep 12, 2024)
- e044850 add fast tests (a-r-r-o-w, Sep 12, 2024)
- f7d8e37 Merge branch 'main' into cogvideox-5b-i2v (a-r-r-o-w, Sep 12, 2024)
- 9f6f3f6 fix import (a-r-r-o-w, Sep 12, 2024)
- 877cdc0 update vae (a-r-r-o-w, Sep 13, 2024)
- 29f1007 update docs (a-r-r-o-w, Sep 13, 2024)
- 0c1358c update image link (a-r-r-o-w, Sep 13, 2024)
- 8222a55 apply suggestions from review (a-r-r-o-w, Sep 13, 2024)
- 61831bd Merge branch 'main' into cogvideox-5b-i2v (a-r-r-o-w, Sep 13, 2024)
- 2d8dce9 apply suggestions from review (a-r-r-o-w, Sep 13, 2024)
- 4f89426 add slow test (a-r-r-o-w, Sep 13, 2024)
- 21a6f79 make use of learned positional embeddings (a-r-r-o-w, Sep 13, 2024)
- 6ce0778 apply suggestions from review (a-r-r-o-w, Sep 13, 2024)
- 7e637d6 Merge branch 'huggingface:main' into cogvideox-5b-i2v (zRzRzRzRzRzRzR, Sep 13, 2024)
- 6f313e8 doc change (zRzRzRzRzRzRzR, Sep 14, 2024)
- ed8bda9 Merge branch 'main' into cogvideox-5b-i2v (a-r-r-o-w, Sep 16, 2024)
- c8ec68c Update convert_cogvideox_to_diffusers.py (zRzRzRzRzRzRzR, Sep 16, 2024)
- 33056c5 make style (a-r-r-o-w, Sep 16, 2024)
- 6dc9bdb final changes (a-r-r-o-w, Sep 16, 2024)
- edeb626 make style (a-r-r-o-w, Sep 16, 2024)
- 380a820 fix tests (a-r-r-o-w, Sep 16, 2024)
15 changes: 12 additions & 3 deletions docs/source/en/api/pipelines/cogvideox.md
@@ -29,9 +29,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).

There are two models available that can be used with the CogVideoX pipeline:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.

There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.

## Inference

@@ -98,6 +101,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- all
- __call__

## CogVideoXImageToVideoPipeline

[[autodoc]] CogVideoXImageToVideoPipeline
- all
- __call__
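
For context, a minimal usage sketch of the new image-to-video pipeline is shown below. The checkpoint id matches the docs above; the input image path and generation parameters are illustrative, not the PR's canonical example.

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# bf16 is the recommended dtype for CogVideoX-5b-I2V (see docs above)
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

# Condition the video on a single image plus a text prompt
image = load_image("input.jpg")  # any local path or URL
prompt = "An astronaut waves at the camera while drifting past the space station."

video = pipe(image=image, prompt=prompt, num_inference_steps=50, guidance_scale=6.0).frames[0]
export_to_video(video, "output.mp4", fps=8)
```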

## CogVideoXVideoToVideoPipeline

[[autodoc]] CogVideoXVideoToVideoPipeline
30 changes: 24 additions & 6 deletions scripts/convert_cogvideox_to_diffusers.py
@@ -4,7 +4,13 @@
import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)


def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
@@ -89,6 +95,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
"freqs_sin": remove_keys_inplace,
"freqs_cos": remove_keys_inplace,
"position_embedding": remove_keys_inplace,
# TODO zRzRzRzRzRzRzR: really need to remove?
"pos_embedding": remove_keys_inplace,
}

VAE_KEYS_RENAME_DICT = {
@@ -131,12 +139,14 @@ def convert_transformer(
num_layers: int,
num_attention_heads: int,
use_rotary_positional_embeddings: bool,
i2v: bool,
dtype: torch.dtype,
):
PREFIX_KEY = "model.diffusion_model."

original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
transformer = CogVideoXTransformer3DModel(
in_channels=32 if i2v else 16,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
@@ -153,7 +163,6 @@
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)

transformer.load_state_dict(original_state_dict, strict=True)
return transformer
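
The `in_channels=32 if i2v else 16` switch reflects how the I2V pipeline feeds the transformer: the encoded image latents are concatenated with the noisy video latents along the channel dimension, doubling the input channels while the output channels stay at 16. A toy sketch follows; the latent shapes are illustrative, not the real latent resolution.

```python
import torch

# [batch, frames, channels, height, width] latent layout used by CogVideoX
video_latents = torch.randn(1, 13, 16, 60, 90)  # noisy video latents
image_latents = torch.randn(1, 13, 16, 60, 90)  # encoded first-frame latents, padded over time

# Concatenating along the channel dim is why the I2V transformer uses in_channels=32
transformer_input = torch.cat([video_latents, image_latents], dim=2)
print(transformer_input.shape)  # torch.Size([1, 13, 32, 60, 90])
```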

@@ -205,6 +214,7 @@ def get_args():
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
return parser.parse_args()


@@ -225,6 +235,7 @@ def get_args():
args.num_layers,
args.num_attention_heads,
args.use_rotary_positional_embeddings,
args.i2v,
dtype,
)
if args.vae_ckpt_path is not None:
@@ -233,7 +244,6 @@ def get_args():
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

# Apparently, the conversion does not work any more without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
@@ -252,9 +262,17 @@
"timestep_spacing": "trailing",
}
)

pipe = CogVideoXPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
if args.i2v:
pipeline_cls = CogVideoXImageToVideoPipeline
else:
pipeline_cls = CogVideoXPipeline

pipe = pipeline_cls(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)

if args.fp16:
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -255,6 +255,7 @@
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CycleDiffusionPipeline",
@@ -703,6 +704,7 @@
AudioLDMPipeline,
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CycleDiffusionPipeline,
14 changes: 10 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -1089,8 +1089,10 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
return self.tiled_encode(x)

frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
enc = []
for i in range(num_frames // frame_batch_size):
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1140,8 +1142,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
return self.tiled_decode(z, return_dict=return_dict)

frame_batch_size = self.num_latent_frames_batch_size
num_batches = num_frames // frame_batch_size
dec = []
for i in range(num_frames // frame_batch_size):
for i in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1233,8 +1236,10 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
time = []
for k in range(num_frames // frame_batch_size):
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
@@ -1309,8 +1314,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
num_batches = num_frames // frame_batch_size
time = []
for k in range(num_frames // frame_batch_size):
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
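
The batching change above is easier to see in isolation. The standalone sketch below mirrors the slice arithmetic from the encode/decode loops; the frame counts used in the calls are illustrative.

```python
def frame_batches(num_frames: int, frame_batch_size: int):
    """Yield (start_frame, end_frame) slices as in the VAE encode/decode loops above."""
    # num_frames is expected to be 1, k * frame_batch_size, or k * frame_batch_size + 1
    num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
    remaining_frames = num_frames % frame_batch_size
    for i in range(num_batches):
        start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
        end_frame = frame_batch_size * (i + 1) + remaining_frames
        yield start_frame, end_frame

print(list(frame_batches(49, 16)))  # [(0, 17), (17, 33), (33, 49)]
print(list(frame_batches(1, 16)))   # [(0, 17)] -- slice end past the tensor is fine in PyTorch
```

With the old `range(num_frames // frame_batch_size)`, a single conditioning image (`num_frames == 1`) produced zero batches and was never encoded; the `num_frames > 1` guard gives it exactly one batch.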
src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -465,8 +465,11 @@ def custom_forward(*inputs):
hidden_states = self.proj_out(hidden_states)

# 5. Unpatchify
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
# - However, CogVideoX-5b-I2V takes concatenated input image latents, so its number of input channels is twice the number of output channels
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

if not return_dict:
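
The switch from `channels` to `-1` in the reshape matters because the I2V transformer's `in_channels` (32) no longer equals its output channels (16); inferring the channel dimension keeps unpatchify correct for both variants. A toy shape check is below; the patch size, frame count, and resolution are illustrative.

```python
import torch

batch_size, num_frames, height, width = 1, 13, 60, 90
p, out_channels = 2, 16  # patch size and output latent channels

# Sequence of patch tokens after proj_out: [batch, frames * (H/p) * (W/p), out_channels * p * p]
hidden_states = torch.randn(batch_size, num_frames * (height // p) * (width // p), out_channels * p * p)

# Using -1 infers the channel dim, so the same code works whether in_channels is 16 or 32
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(output.shape)  # torch.Size([1, 13, 16, 60, 90])
```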
8 changes: 6 additions & 2 deletions src/diffusers/pipelines/__init__.py
@@ -138,7 +138,11 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -461,7 +465,7 @@
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/cogvideo/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
_import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline

else: