Add FusedRMSNorm (Triton kernel, +15% eager), Add NPLayerNorm, Enable config selectable Norm Type (#181)

This PR has multiple aspects:
1 - Adds a new Triton-based Fused RMSNorm I wrote. I've verified its
numerical accuracy on both forward and backward with a unit test.
It improves MFU by +15% in eager mode with FSDP2 on the 7B model, and improves compile mode slightly, by +1.2%:
<img width="545" alt="Screenshot 2024-03-29 at 5 18 14 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/8f16fae9-947b-4720-a370-b954779c33a7">

2 - Adds norms.py to house all 4 norm types, standardizing on
[layernorm / np_layernorm / rmsnorm / fused_rmsnorm]. norms.py has a
create_norm function that creates the appropriate norm.

3 - Adds np_layernorm, which is layernorm with no affine transformation.
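
For reference, a minimal sketch of what such a factory can look like. The `create_norm(norm_type, dim=..., eps=...)` signature matches its usage in the model.py diff further down; the eager RMSNorm shown is essentially the class removed from model.py, while the Triton FusedRMSNorm is elided because the kernel is too long to reproduce here:

~~~python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Eager RMSNorm (essentially the class moved out of model.py by this PR)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for stability, then cast back to the input dtype.
        output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return output.type_as(x) * self.weight


def create_norm(norm_type: str, dim: int, eps: float = 1e-6) -> nn.Module:
    """Dispatch on the standardized norm-type names used in this PR."""
    norm_type = norm_type.lower()
    if norm_type == "layernorm":
        return nn.LayerNorm(dim, eps=eps)
    if norm_type == "np_layernorm":
        # LayerNorm with no affine transformation (no learnable weight/bias).
        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
    if norm_type == "rmsnorm":
        return RMSNorm(dim, eps=eps)
    if norm_type == "fused_rmsnorm":
        # The actual norms.py returns the Triton-based FusedRMSNorm here;
        # the kernel is omitted from this sketch.
        raise NotImplementedError("see norms.py in this commit for the Triton kernel")
    raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
~~~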

4 - Updates model.py to support plug-and-play of any supported norm.

Thus instead of this type of if/then logic in the model class:
<img width="928" alt="Screenshot 2024-03-30 at 1 52 07 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/ba7cb976-580f-4471-a79b-a584f7d20693">

We simply have this:
<img width="1129" alt="Screenshot 2024-03-30 at 1 52 23 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/aba48b4d-1620-4059-840d-e620468f00f2">

This allows easy plug-and-play of any norm type without fiddling
around in the model code.

5 - Updates run_llama_train.sh to pick a free rendezvous port at random
(via the c10d backend and localhost:0) instead of the previous fixed
port number. (Thanks @yifuwang for this tip!)


6 - Now users can quickly select the norm of their choice via the config
file:
<img width="774" alt="Screenshot 2024-03-30 at 3 01 43 PM"
src="https://github.com/pytorch/torchtrain/assets/46302957/3238b375-dc21-4ee2-a5fa-f6571da79edb">

7 - Adds a NotImplementedError if users try to run TP + fused_rmsnorm, to
avoid any confusion (per @tianyu-l's feedback):
~~~
NotImplementedError: fused_rmsnorm not yet compatible with TP. Please
use rmsnorm.
~~~
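
The guard itself is presumably a simple early check in the TP setup path; a minimal sketch, where the function and argument names are assumptions and only the error message comes from this PR:

~~~python
def ensure_norm_supports_tp(norm_type: str, tensor_parallel_degree: int) -> None:
    """Fail fast rather than let TP silently misbehave with the fused Triton kernel."""
    if tensor_parallel_degree > 1 and norm_type == "fused_rmsnorm":
        raise NotImplementedError(
            "fused_rmsnorm not yet compatible with TP. Please use rmsnorm."
        )
~~~
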
lessw2020 authored Apr 5, 2024
1 parent 5e729a0 commit e218fb3
Showing 11 changed files with 346 additions and 56 deletions.
.flake8 (3 changes: 2 additions & 1 deletion)
@@ -7,8 +7,9 @@ max-line-length = 120
# N812 ignored because import torch.nn.functional as F is PyTorch convention
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
run_llama_train.sh (2 changes: 1 addition & 1 deletion)
@@ -24,6 +24,6 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
torchtrain/config_manager.py (6 changes: 6 additions & 0 deletions)
@@ -117,6 +117,12 @@ def __init__(self):
default="debugmodel",
help="which model config to train",
)
self.parser.add_argument(
"--model.norm_type",
type=str,
default="rmsnorm",
help="Layer Normalization type to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
type=str,
torchtrain/models/llama/model.py (67 changes: 13 additions & 54 deletions)
@@ -7,6 +7,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtrain.models.norms import create_norm


@dataclass
@@ -25,57 +26,7 @@ class ModelArgs:
depth_init: bool = (
True # initialization uses each unique layer_id or total model layer count
)


class RMSNorm(torch.nn.Module):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""

def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))
self.reset_parameters()

def _norm(self, x: torch.Tensor):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x: torch.Tensor):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
return output * self.weight

def reset_parameters(self):
torch.nn.init.ones_(self.weight)
norm_type: str = "rmsnorm"


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@@ -381,8 +332,13 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
)
self.layer_id = layer_id
self.num_layers = model_args.n_layers
self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)

self.attention_norm = create_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.ffn_norm = create_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

if model_args.depth_init:
self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
@@ -447,7 +403,10 @@ def __init__(self, model_args: ModelArgs):
for layer_id in range(model_args.n_layers):
self.layers.append(TransformerBlock(layer_id, model_args))

self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.norm = create_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()
