new vfm training features (#11246)
Signed-off-by: Zeeshan Patel <[email protected]>
Co-authored-by: Zeeshan Patel <[email protected]>
zpx01 and Zeeshan Patel authored Nov 13, 2024
1 parent 085e957 commit 6e8e974
Showing 11 changed files with 1,097 additions and 114 deletions.
58 changes: 46 additions & 12 deletions nemo/collections/diffusion/data/diffusion_energon_datamodule.py
@@ -15,7 +15,8 @@
 import logging
 from typing import Any, Dict, Literal
 
-from megatron.energon import DefaultTaskEncoder, get_train_dataset
+from megatron.core import parallel_state
+from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS
 
 from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule
@@ -56,6 +57,9 @@ def __init__(
         pin_memory: bool = True,
         task_encoder: DefaultTaskEncoder = None,
         use_train_split_for_val: bool = False,
+        virtual_epoch_length: int = 1_000_000_000,  # a hack to avoid energon end of epoch warning
+        packing_buffer_size: int | None = None,
+        max_samples_per_sequence: int | None = None,
     ) -> None:
         """
         Initialize the SimpleMultiModalDataModule.
@@ -82,6 +86,10 @@ def __init__(
             task_encoder=task_encoder,
         )
         self.use_train_split_for_val = use_train_split_for_val
+        self.virtual_epoch_length = virtual_epoch_length
+        self.num_workers_val = 1
+        self.packing_buffer_size = packing_buffer_size
+        self.max_samples_per_sequence = max_samples_per_sequence
 
     def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
         """
@@ -106,29 +114,55 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
             batch_size=self.micro_batch_size,
             task_encoder=self.task_encoder,
             worker_config=worker_config,
-            max_samples_per_sequence=None,
-            shuffle_buffer_size=100,
+            max_samples_per_sequence=self.max_samples_per_sequence,
+            shuffle_buffer_size=None,
             split_part=split,
             batch_drop_last=True,
-            virtual_epoch_length=1_000_000_000,  # a hack to avoid energon end of epoch warning
+            virtual_epoch_length=self.virtual_epoch_length,
+            packing_buffer_size=self.packing_buffer_size,
         )
         return _dataset
 
     def val_dataloader(self) -> EVAL_DATALOADERS:
         """
-        Configure the validation DataLoader.
-        This method configures the DataLoader for validation data.
-        Parameters:
-            worker_config: Configuration for the data loader workers.
+        Initialize and return the validation DataLoader.
+        This method initializes the DataLoader for the validation dataset. It ensures that the parallel state
+        is initialized correctly for distributed training and returns a configured DataLoader object.
         Returns:
-            DataLoader: The DataLoader for validation data.
+            EVAL_DATALOADERS: The DataLoader for the validation dataset.
         """
         if self.use_train_split_for_val:
            return self.train_dataloader()
-        return super().val_dataloader()
+        if self.val_dataloader_object:
+            return self.val_dataloader_object
+
+        if not parallel_state.is_initialized():
+            message = (
+                "Multimodal val data loader parallel state is not initialized, "
+                f"using default worker config with num_workers {self.num_workers_val}"
+            )
+            logging.info(message)
+
+            worker_config = WorkerConfig.default_worker_config(self.num_workers_val)
+        else:
+            rank = parallel_state.get_data_parallel_rank()
+            world_size = parallel_state.get_data_parallel_world_size()
+            data_parallel_group = parallel_state.get_data_parallel_group()
+
+            logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}")
+            worker_config = WorkerConfig(
+                rank=rank,
+                world_size=world_size,
+                num_workers=self.num_workers_val,
+                data_parallel_group=data_parallel_group,
+                worker_debug_path=None,
+                worker_log_level=0,
+            )
+        val_dataset = self.datasets_provider(worker_config, split='val')
+        energon_loader = get_savable_loader(val_dataset, worker_config=worker_config)
+        self.val_dataloader_object = energon_loader
+        return self.val_dataloader_object
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """
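A rough usage sketch of the new options above; the `DiffusionDataModule` class name and the dataset path are assumptions for illustration, and only `virtual_epoch_length`, `packing_buffer_size`, and `max_samples_per_sequence` are taken from this diff.

```python
# Hypothetical wiring of the new datamodule options; the class name and path are
# assumed for illustration and may differ from the actual module contents.
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule

data_module = DiffusionDataModule(
    path="/data/energon_webdataset",      # assumed energon dataset location
    micro_batch_size=1,
    global_batch_size=32,
    num_workers=8,
    virtual_epoch_length=1_000_000_000,   # avoids energon's end-of-epoch warning
    packing_buffer_size=100,              # forwarded to energon's get_train_dataset
    max_samples_per_sequence=None,        # no per-sequence sample cap
)

# val_dataloader() now builds a rank-aware WorkerConfig (or a default single-worker
# config when megatron parallel state is uninitialized) and wraps the dataset with
# get_savable_loader so the loader state can be saved and restored.
val_loader = data_module.val_dataloader()
```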
218 changes: 218 additions & 0 deletions nemo/collections/diffusion/data/diffusion_fake_datamodule.py
@@ -0,0 +1,218 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader

from nemo.collections.diffusion.models.model import DiTConfig
from nemo.lightning.pytorch.plugins import MegatronDataSampler

from .diffusion_taskencoder import pos_id_3d


class PosEmb3D:
"""Generates and provides 3D positional embeddings for video data."""

def __init__(self, *, max_t=96, max_h=960, max_w=960):
self.max_t = max_t
self.max_h = max_h
self.max_w = max_w
self.generate_pos_id()

def generate_pos_id(self):
"""Generates the positional ID grid based on max_t, max_h, and max_w."""
self.grid = torch.stack(
torch.meshgrid(
torch.arange(self.max_t, device='cpu'),
torch.arange(self.max_h, device='cpu'),
torch.arange(self.max_w, device='cpu'),
),
dim=-1,
)

def get_pos_id_3d(self, *, t, h, w):
"""Retrieves a subset of the positional IDs for the specified dimensions.
Parameters:
t (int): Number of time frames.
h (int): Height dimension.
w (int): Width dimension.
Returns:
torch.Tensor: The positional IDs tensor with shape (t, h, w, 3).
"""
if t > self.max_t or h > self.max_h or w > self.max_w:
self.max_t = max(self.max_t, t)
self.max_h = max(self.max_h, h)
self.max_w = max(self.max_w, w)
self.generate_pos_id()
return self.grid[:t, :h, :w]


class DiTVideoLatentFakeDataset(torch.utils.data.Dataset):
"""A fake dataset for generating synthetic video latent data."""

def __init__(
self,
n_frames,
max_h,
max_w,
patch_size,
in_channels,
crossattn_emb_size,
max_text_seqlen=512,
seq_length=8192,
):
self.max_t = n_frames
self.max_height = max_h
self.max_width = max_w
self.patch_size = patch_size
self.in_channels = in_channels
self.text_dim = crossattn_emb_size
self.text_seqlen = max_text_seqlen
self.seq_length = seq_length

def __len__(self):
"""Returns the total number of samples."""
return 100000000

def __getitem__(self, idx):
"""Generates a single sample of data.
Parameters:
idx (int): Index of the data sample.
Returns:
dict: A dictionary containing video latent data and related information.
"""
t = self.max_t
h = self.max_height
w = self.max_width
p = self.patch_size
c = self.in_channels

video_latent = torch.ones(self.seq_length, c * p**2, dtype=torch.bfloat16) * 0.5
text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16)
pos_emb = pos_id_3d.get_pos_id_3d(t=t, h=h // p, w=w // p).reshape(-1, 3)

return {
'video': video_latent,
't5_text_embeddings': text_embedding,
'seq_len_q': torch.tensor([video_latent.shape[0]], dtype=torch.int32).squeeze(),
'seq_len_kv': torch.tensor([self.text_seqlen], dtype=torch.int32).squeeze(),
'pos_ids': torch.zeros((self.seq_length, 3), dtype=torch.int32),
'loss_mask': torch.ones(video_latent.shape[0], dtype=torch.bfloat16),
}

def _collate_fn(self, batch):
"""A default implementation of a collation function.
Users should override this method to define custom data loaders.
"""
return torch.utils.data.dataloader.default_collate(batch)

def collate_fn(self, batch):
"""Method that user passes as a functor to DataLoader.
The method optionally performs neural type checking and adds types to the outputs.
Please note, subclasses of Dataset should not implement `input_types`.
Usage:
dataloader = torch.utils.data.DataLoader(
....,
collate_fn=dataset.collate_fn,
....
)
Returns:
Collated batch, with or without types.
"""
return self._collate_fn(batch)


class VideoLatentFakeDataModule(pl.LightningDataModule):
"""A LightningDataModule for generating fake video latent data for training."""

def __init__(
self,
model_config: DiTConfig,
seq_length: int = 2048,
micro_batch_size: int = 1,
global_batch_size: int = 8,
num_workers: int = 1,
pin_memory: bool = True,
task_encoder=None,
use_train_split_for_val: bool = False,
) -> None:
super().__init__()
self.seq_length = seq_length
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.num_workers = num_workers
self.model_config = model_config

self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
)

def setup(self, stage: str = "") -> None:
"""Sets up the dataset for training and validation.
Parameters:
stage (str): Optional stage argument (unused).
"""
self._train_ds = DiTVideoLatentFakeDataset(
n_frames=self.model_config.max_frames,
max_h=self.model_config.max_img_h,
max_w=self.model_config.max_img_w,
patch_size=self.model_config.patch_spatial,
in_channels=self.model_config.in_channels,
crossattn_emb_size=self.model_config.crossattn_emb_size,
)

def train_dataloader(self) -> TRAIN_DATALOADERS:
"""Returns the training DataLoader."""
if not hasattr(self, "_train_ds"):
self.setup()
return self._create_dataloader(self._train_ds)

def val_dataloader(self) -> EVAL_DATALOADERS:
"""Returns the validation DataLoader."""
if not hasattr(self, "_train_ds"):
self.setup()
return self._create_dataloader(self._train_ds)

def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
"""Creates a DataLoader for the given dataset.
Parameters:
dataset (Dataset): The dataset to load.
**kwargs: Additional arguments for DataLoader.
Returns:
DataLoader: The DataLoader instance.
"""
return DataLoader(
dataset,
num_workers=self.num_workers,
pin_memory=True,
persistent_workers=True,
collate_fn=dataset.collate_fn,
**kwargs,
)
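A minimal smoke-test sketch for the fake datamodule above; the `DiTConfig` default values are assumed to be usable as-is, and only the field names read in `setup()` (`max_frames`, `max_img_h`, `max_img_w`, `patch_spatial`, `in_channels`, `crossattn_emb_size`) come from the code.

```python
# Hypothetical standalone smoke test; assumes DiTConfig() is constructible with
# defaults that provide the fields read in VideoLatentFakeDataModule.setup().
from nemo.collections.diffusion.data.diffusion_fake_datamodule import VideoLatentFakeDataModule
from nemo.collections.diffusion.models.model import DiTConfig

data_module = VideoLatentFakeDataModule(
    model_config=DiTConfig(),
    micro_batch_size=1,
    global_batch_size=8,
    num_workers=1,
)
data_module.setup()

# Without a trainer attached this is a plain torch DataLoader with the default
# batch size of 1; the dataset keeps its own seq_length default of 8192.
batch = next(iter(data_module.train_dataloader()))
print(batch['video'].shape)               # [1, 8192, in_channels * patch_spatial**2]
print(batch['t5_text_embeddings'].shape)  # [1, 512, crossattn_emb_size]
print(batch['loss_mask'].dtype)           # torch.bfloat16
```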