
[refactor] CogVideoX followups + tiled decoding support #9150

Merged
19 commits merged on Aug 13, 2024
18 changes: 8 additions & 10 deletions docs/source/en/api/pipelines/cogvideox.md
@@ -33,9 +33,7 @@ This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRz

## Inference

Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

First, load the pipeline:
Inference can be run with CogVideoX by following the code below:

```python
import torch
@@ -55,29 +53,29 @@ video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
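The diff view collapses the middle of this snippet. For reference, a minimal end-to-end sketch might look like the following; the checkpoint id and dtype below are assumptions for illustration rather than part of this diff:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint id and dtype; swap in whichever CogVideoX checkpoint you use.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

prompt = "A panda playing a tiny guitar in a serene bamboo forest."
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```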

Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.

Then change the memory layout of the pipeline's `transformer` component to `torch.channels_last`:

```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```

Finally, compile the components and run inference:

```python
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)

# CogVideoX works very well with long and well-described prompts
# CogVideoX works well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```

The [benchmark](TODO: link) results on an 80GB A100 machine are:
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:

```
Without torch.compile(): Average inference time: TODO seconds.
With torch.compile(): Average inference time: TODO seconds.
Without torch.compile(): Average inference time: 96.89 seconds.
With torch.compile(): Average inference time: 84.60 seconds.
```
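For readers who want to reproduce these numbers, a rough timing sketch (assuming `pipeline` and `prompt` are set up as in the snippets above) could look like this; it is not the exact benchmark script linked above:

```python
import time
import torch

def average_inference_time(pipeline, prompt, runs=3):
    # Average the end-to-end call over a few runs, synchronizing so GPU work is counted.
    times = []
    for _ in range(runs):
        torch.cuda.synchronize()
        start = time.time()
        pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50)
        torch.cuda.synchronize()
        times.append(time.time() - start)
    return sum(times) / len(times)
```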

## CogVideoXPipeline
216 changes: 186 additions & 30 deletions src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -118,36 +118,30 @@ def __init__(
self.conv_cache = None

def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
dim = self.temporal_dim
kernel_size = self.time_kernel_size
if kernel_size == 1:
return inputs

inputs = inputs.transpose(0, dim)

if self.conv_cache is not None:
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
else:
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)

inputs = inputs.transpose(0, dim).contiguous()
if kernel_size > 1:
cached_inputs = (
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs

def _clear_fake_context_parallel_cache(self):
del self.conv_cache
self.conv_cache = None

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
input_parallel = self.fake_context_parallel_forward(inputs)
inputs = self.fake_context_parallel_forward(inputs)

self._clear_fake_context_parallel_cache()
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
# Note: we could move these to the cpu for a lower maximum memory usage but it's only a few
# hundred megabytes and so let's not do it for now
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)

output_parallel = self.conv(input_parallel)
output = output_parallel
output = self.conv(inputs)
return output
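For intuition, here is a standalone toy sketch (not the library class) of the padding scheme the forward pass above implements: time is padded causally by replicating the first frame `kernel_size - 1` times, while height and width get ordinary zero padding.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 16, 16)  # (batch, channels, frames, height, width)
time_kernel_size, height_pad, width_pad = 3, 1, 1

# Causal temporal padding: frame t only ever sees frames <= t.
x = torch.cat([x[:, :, :1]] * (time_kernel_size - 1) + [x], dim=2)
# Ordinary symmetric zero padding over width and height.
x = F.pad(x, (width_pad, width_pad, height_pad, height_pad), mode="constant", value=0)

print(x.shape)  # torch.Size([1, 8, 7, 18, 18])
```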


@@ -911,7 +905,8 @@ def __init__(
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_size: int = 256,
sample_height: int = 480,
sample_width: int = 720,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
@@ -950,25 +945,64 @@ def __init__(
self.use_slicing = False
self.use_tiling = False

self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
# We make the minimum tile height and width half of the generally supported sample size
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.33

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value

def clear_fake_context_parallel_cache(self):
def _clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()

def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))

def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False

def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True

def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
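As a usage note, here is a minimal sketch of how these toggles are intended to be driven from a pipeline; the checkpoint id is an assumption, and `pipe.vae` follows the usual diffusers pipeline layout:

```python
import torch
from diffusers import CogVideoXPipeline

# Assumed checkpoint id; any CogVideoX checkpoint using this VAE should behave the same way.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# Cap peak VAE memory: decode spatial tiles one at a time and batch slices one at a time.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

video = pipe(prompt="A panda playing a tiny guitar", num_inference_steps=50).frames[0]
```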

@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -993,8 +1027,34 @@ def encode(
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)

def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_width or z.shape[-2] > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)

num_latent_frames = z.shape[2]
dec = []
for i in range(num_latent_frames // 2):
if num_latent_frames % 2 == 0:
start_frame, end_frame = (2 * i, 2 * i + 2)
else:
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)

z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate = self.decoder(z_intermediate)
dec.append(z_intermediate)

self._clear_fake_context_parallel_cache()
dec = torch.cat(dec, dim=2)

if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)
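A quick standalone check of the frame-chunking schedule used in `_decode` above: latent frames are decoded two at a time, and an odd frame count folds the extra frame into a first chunk of three.

```python
def chunk_schedule(num_latent_frames: int):
    # Mirrors the (start_frame, end_frame) pairs produced by the loop in _decode.
    chunks = []
    for i in range(num_latent_frames // 2):
        if num_latent_frames % 2 == 0:
            chunks.append((2 * i, 2 * i + 2))
        else:
            chunks.append((0, 3) if i == 0 else (2 * i + 1, 2 * i + 3))
    return chunks

print(chunk_schedule(4))  # [(0, 2), (2, 4)]
print(chunk_schedule(5))  # [(0, 3), (3, 5)]
```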

@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.

@@ -1007,13 +1067,109 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample

if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)

def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b

def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
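A tiny standalone check of the crossfade implemented by `blend_h` (the same idea applies along the height axis in `blend_v`): across the overlap, the left tile fades out linearly while the right tile fades in, hiding the seam.

```python
import torch

a = torch.ones(1, 1, 1, 1, 4)   # left tile, value 1 everywhere
b = torch.zeros(1, 1, 1, 1, 4)  # right tile, value 0 everywhere
blend_extent = 4

for x in range(blend_extent):
    b[..., x] = a[..., -blend_extent + x] * (1 - x / blend_extent) + b[..., x] * (x / blend_extent)

print(b.flatten())  # tensor([1.0000, 0.7500, 0.5000, 0.2500])
```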

def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
dec = self.decoder(z)
# Rough memory assessment:
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width

# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[3], overlap_height):
row = []
for j in range(0, z.shape[4], overlap_width):
time = []
for k in range(z.shape[2] // 2):
if z.shape[2] % 2 == 0:
start_frame, end_frame = (2 * k, 2 * k + 2)
else:
start_frame, end_frame = (0, 3) if k == 0 else (2 * k + 1, 2 * k + 3)
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile = self.decoder(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))

dec = torch.cat(result_rows, dim=3)

if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)

def forward(
6 changes: 3 additions & 3 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -304,7 +304,7 @@ def forward(
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]

# 5. Transformer blocks
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:

@@ -331,11 +331,11 @@ def custom_forward(*inputs):

hidden_states = self.norm_final(hidden_states)

# 6. Final block
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)

# 7. Unpatchify
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)