
Commit

added dit llama models
Signed-off-by: Zeeshan Patel <[email protected]>
zpx01 committed Oct 11, 2024
1 parent a42f7d2 commit f706757
Showing 3 changed files with 246 additions and 0 deletions.
13 changes: 13 additions & 0 deletions nemo/collections/diffusion/models/dit_llama/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
173 changes: 173 additions & 0 deletions nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py
@@ -0,0 +1,173 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Literal, Optional

from megatron.core.transformer.attention import (
CrossAttention,
CrossAttentionSubmodules,
SelfAttention,
SelfAttentionSubmodules,
)
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TERowParallelLinear,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import make_viewless_tensor

from nemo.collections.diffusion.models.dit.dit_layer_spec import AdaLN


class MoviegGenLayer(TransformerLayer):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
DiT with Adapative Layer Normalization.
"""

def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
        hidden_dropout: Optional[float] = None,
position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
):
def _replace_no_cp_submodules(submodules):
modified_submods = copy.deepcopy(submodules)
modified_submods.cross_attention = IdentityOp
# modified_submods.temporal_self_attention = IdentityOp
return modified_submods

# Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init.
modified_submods = _replace_no_cp_submodules(submodules)
super().__init__(
config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout
)

        # Override cross-attention to disable context parallelism (CP).
        # Also disable TP comm overlap; leaving it enabled tries to reuse a buffer
        # sized for Q and leads to incorrect tensor shapes.
cp_override_config = copy.deepcopy(config)
cp_override_config.context_parallel_size = 1
cp_override_config.tp_comm_overlap = False
self.cross_attention = build_module(
submodules.cross_attention,
config=cp_override_config,
layer_number=layer_number,
)

self.adaLN = AdaLN(config=self.config, n_adaln_chunks=6) # , norm=TENorm)
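        # AdaLN regresses six modulation chunks from the timestep embedding:
        # (shift, scale, gate) for the self-attention branch and for the MLP branch.
        # forward() applies them as a modulated LayerNorm before each sub-layer
        # and a gated residual add after it.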

def forward(
self,
hidden_states,
attention_mask,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
):
        # The attention_mask argument is repurposed to carry the timestep embedding,
        # and rotary_pos_emb carries the factorized positional embedding.
timestep_emb = attention_mask
factorized_pos_emb = rotary_pos_emb
hidden_states = hidden_states + factorized_pos_emb

# ******************************************** full self attention ******************************************************
shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb)

# adaLN with scale + shift
pre_full_attn_layernorm_output_ada = self.adaLN.modulated_layernorm(
hidden_states, shift=shift_full, scale=scale_full
)

attention_output, _ = self.self_attention(
pre_full_attn_layernorm_output_ada,
attention_mask=None,
packed_seq_params=None if packed_seq_params is None else packed_seq_params['self_attention'],
)

hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full)

# ******************************************** cross attention ******************************************************
attention_output, _ = self.cross_attention(
hidden_states,
attention_mask=context_mask,
key_value_states=context,
packed_seq_params=None if packed_seq_params is None else packed_seq_params['cross_attention'],
)

# ******************************************** mlp ******************************************************
pre_mlp_layernorm_output_ada = self.adaLN.modulated_layernorm(
attention_output, shift=shift_mlp, scale=scale_mlp
)

mlp_output, _ = self.mlp(pre_mlp_layernorm_output_ada)
hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp)

# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)

return output, context


def get_dit_llama_spec() -> ModuleSpec:
    """Build the ModuleSpec for a DiT-LLaMA transformer layer with Transformer Engine submodules."""
    params = {"attn_mask_type": AttnMaskType.padding}
return ModuleSpec(
module=MoviegGenLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params=params,
submodules=SelfAttentionSubmodules(
linear_qkv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
cross_attention=ModuleSpec(
module=CrossAttention,
params=params,
submodules=CrossAttentionSubmodules(
linear_q=TEColumnParallelLinear,
linear_kv=TEColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
mlp=ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear,
linear_fc2=TERowParallelLinear,
),
),
),
)
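
For orientation, here is a minimal sketch of how the returned spec could be instantiated (assuming Megatron-LM model-parallel state is already initialized and Transformer Engine is available; the config values are placeholders, not taken from this commit):

from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig

from nemo.collections.diffusion.models.dit_llama.dit_llama_layer_spec import get_dit_llama_spec

# Placeholder hyperparameters for illustration only.
config = TransformerConfig(num_layers=1, hidden_size=1024, num_attention_heads=16)

# build_module resolves the ModuleSpec into a MoviegGenLayer whose linear and
# attention submodules are the TE-backed classes declared in the spec.
layer = build_module(get_dit_llama_spec(), config=config, layer_number=1)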
60 changes: 60 additions & 0 deletions nemo/collections/diffusion/models/dit_llama/dit_llama_model.py
@@ -0,0 +1,60 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Literal

from megatron.core.transformer.transformer_config import TransformerConfig

from nemo.collections.diffusion.models.dit import dit_embeddings
from nemo.collections.diffusion.models.dit.dit_model import DiTCrossAttentionModel
from nemo.collections.diffusion.models.dit_llama.dit_llama_layer_spec import get_dit_llama_spec


class DiTLlamaModel(DiTCrossAttentionModel):
    """Cross-attention DiT model built from the DiT-LLaMA layer spec with factorized learnable 3D positional embeddings."""

def __init__(
self,
config: TransformerConfig,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
parallel_output: bool = True,
position_embedding_type: Literal["learned_absolute", "rope"] = "rope",
max_img_h: int = 80,
max_img_w: int = 80,
max_frames: int = 34,
patch_spatial: int = 1,
patch_temporal: int = 1,
in_channels: int = 16,
out_channels: int = 16,
**kwargs,
):
super().__init__(
config=config,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=fp16_lm_cross_entropy,
parallel_output=parallel_output,
position_embedding_type=position_embedding_type,
max_img_h=max_img_h,
max_img_w=max_img_w,
max_frames=max_frames,
patch_spatial=patch_spatial,
patch_temporal=patch_temporal,
in_channels=in_channels,
out_channels=out_channels,
transformer_decoder_layer_spec=get_dit_llama_spec,
pos_embedder=dit_embeddings.FactorizedLearnable3DEmbedding,
**kwargs,
)
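
As a rough usage sketch (hypothetical values, not from this commit; assumes distributed and model-parallel state are already initialized):

from megatron.core.transformer.transformer_config import TransformerConfig

from nemo.collections.diffusion.models.dit_llama.dit_llama_model import DiTLlamaModel

# Illustrative configuration; real layer counts and sizes come from the training recipe.
config = TransformerConfig(num_layers=32, hidden_size=3072, num_attention_heads=24)

model = DiTLlamaModel(
    config=config,
    max_img_h=80,
    max_img_w=80,
    max_frames=34,
    in_channels=16,
    out_channels=16,
)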
