Add FusedRMSNorm (Triton kernel, +15% eager), Add NPLayerNorm, Enable config selectable Norm Type #181

Merged Apr 5, 2024 · 31 commits · Changes from 28 commits

Commits
457abee
Create fused_rms_norm.py
lessw2020 Mar 29, 2024
a2435df
start unifying all norms with single base class
lessw2020 Mar 30, 2024
de2a41b
add strEnum for NormTypes
lessw2020 Mar 30, 2024
e8782f1
link norm type control to model config from train config
lessw2020 Mar 30, 2024
d797d27
link norm type control to model config from train config
lessw2020 Mar 30, 2024
81f3402
all working
lessw2020 Mar 30, 2024
a1b45ba
linting
lessw2020 Mar 30, 2024
0c0d41f
linting, remove triton file from linting check
lessw2020 Mar 30, 2024
47a5791
add default rms in config_manager
lessw2020 Mar 30, 2024
be88137
lint default rms in config_manager, update 7b, 70b toml
lessw2020 Mar 30, 2024
ff37024
remove alternative backward pass section (saved in kernel repo)
lessw2020 Mar 31, 2024
7c0c62b
revert reshard_after_forward from MFU tuning
lessw2020 Apr 1, 2024
f81ff3f
merge all norms into single file per PR feedback
lessw2020 Apr 2, 2024
87def6a
move norms.py into models dir
lessw2020 Apr 2, 2024
df29e6e
remove StrEnum, add localhost:0
lessw2020 Apr 4, 2024
aef4e0c
change reshape to view in triton code
lessw2020 Apr 5, 2024
9088524
standardize norms to 4 express names
lessw2020 Apr 5, 2024
d32fe71
remove normbase
lessw2020 Apr 5, 2024
da4bd20
correct merge conflict in run.sh
lessw2020 Apr 5, 2024
310be84
update to remove debug2d.toml
lessw2020 Apr 5, 2024
6747119
correct mergeconflict 2 in run.sh
lessw2020 Apr 5, 2024
2d46786
Merge branch 'main' into fused_rms_norm
lessw2020 Apr 5, 2024
83ec4b6
linting
lessw2020 Apr 5, 2024
8750043
add comment regarding flake8 N803,N806
lessw2020 Apr 5, 2024
701e0d3
add check for fused_rmsnorm when TP active.
lessw2020 Apr 5, 2024
2a3ac78
linting
lessw2020 Apr 5, 2024
d5b4b32
formatting
lessw2020 Apr 5, 2024
5a937e9
remove init_weights, streamline reset_params
lessw2020 Apr 5, 2024
54cf135
updated TP norm msg, reset 7b steps
lessw2020 Apr 5, 2024
239b71f
update kernel name to TritonFusedRMSNorm
lessw2020 Apr 5, 2024
1d84539
Merge branch 'main' into fused_rms_norm
lessw2020 Apr 5, 2024
3 changes: 2 additions & 1 deletion .flake8
@@ -7,8 +7,9 @@ max-line-length = 120
# N812 ignored because import torch.nn.functional as F is PyTorch convention
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
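The N803/N806 exemptions are needed because Triton kernels conventionally use uppercase names for tensor pointers (X, W, Y) and compile-time constants (BLOCK_SIZE), which flake8's pep8-naming checks would otherwise flag. A minimal sketch of an RMSNorm forward kernel in that style, for illustration only — this is not the fused_rms_norm.py kernel added in this PR, and the kernel and wrapper names here are hypothetical:

import torch
import triton
import triton.language as tl


@triton.jit
def _rms_norm_fwd(X, W, Y, stride, N, eps, BLOCK_SIZE: tl.constexpr):
    # One program per row: load the row, compute rms = sqrt(mean(x^2) + eps),
    # normalize, scale by the weight vector, and write the result back.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    x = tl.load(X + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    rms = tl.sqrt(tl.sum(x * x, axis=0) / N + eps)
    y = x / rms * w
    tl.store(Y + row * stride + cols, y, mask=mask)


def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (n_rows, dim) CUDA tensor; weight: (dim,) learnable scale.
    y = torch.empty_like(x)
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    _rms_norm_fwd[(n_rows,)](x, weight, y, x.stride(0), n_cols, eps, BLOCK_SIZE=BLOCK_SIZE)
    return y

The uppercase kernel arguments would trigger N803 and the BLOCK_SIZE local would trigger N806, which is exactly what the new ignore entries allow.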
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -24,6 +24,6 @@ if [ $# -ne 0 ]; then
overrides="$*"
fi

torchrun --nproc_per_node=${NGPU} --rdzv_endpoint="localhost:5972" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
train.py --job.config_file ${CONFIG_FILE} $overrides
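The endpoint change (fixed port 5972 → localhost:0 with the c10d rendezvous backend) presumably lets torchrun bind to an OS-assigned free port, avoiding collisions when multiple training jobs share a host; see the "remove StrEnum, add localhost:0" commit above.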
6 changes: 6 additions & 0 deletions torchtrain/config_manager.py
@@ -117,6 +117,12 @@ def __init__(self):
default="debugmodel",
help="which model config to train",
)
self.parser.add_argument(
"--model.norm_type",
type=str,
default="rmsnorm",
help="Layer Normalization type to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
type=str,
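With this argument in place, the norm implementation becomes a per-run choice rather than a code edit: an invocation such as ./run_llama_train.sh --model.norm_type fused_rmsnorm would flow through the $overrides variable in run_llama_train.sh into train.py (the exact override spelling here is illustrative; per the commit list, the 7B/70B .toml training configs were also updated alongside this flag).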
67 changes: 13 additions & 54 deletions torchtrain/models/llama/model.py
@@ -7,6 +7,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtrain.models.norms import create_norm


@dataclass
@@ -25,57 +26,7 @@ class ModelArgs:
depth_init: bool = (
True # initialization uses each unique layer_id or total model layer count
)


class RMSNorm(torch.nn.Module):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""

def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim))
self.reset_parameters()

def _norm(self, x: torch.Tensor):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x: torch.Tensor):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x.float()).type_as(x)
return output * self.weight

def reset_parameters(self):
torch.nn.init.ones_(self.weight)
norm_type: str = "rmsnorm"


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
@@ -375,8 +326,13 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
)
self.layer_id = layer_id
self.num_layers = model_args.n_layers
self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)

self.attention_norm = create_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)
self.ffn_norm = create_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

if model_args.depth_init:
self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
@@ -441,7 +397,10 @@ def __init__(self, model_args: ModelArgs):
for layer_id in range(model_args.n_layers):
self.layers.append(TransformerBlock(layer_id, model_args))

self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.norm = create_norm(
model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
)

self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
self.init_weights()

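The three call sites above all go through create_norm, but torchtrain/models/norms.py itself is outside this excerpt. A minimal sketch of what the factory's interface implies, assuming "np_layernorm" means a non-parametric LayerNorm (no learnable affine parameters) and with the Triton-backed branch stubbed out:

import torch
from torch import nn


class RMSNorm(nn.Module):
    # Same math as the RMSNorm class removed from model.py above.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight

    def reset_parameters(self):
        nn.init.ones_(self.weight)


def create_norm(norm_type: str, dim: int, eps: float = 1e-6) -> nn.Module:
    # Dispatch on the four names listed for --model.norm_type in config_manager.py.
    norm_type = norm_type.lower()
    if norm_type == "layernorm":
        return nn.LayerNorm(dim, eps=eps)
    if norm_type == "np_layernorm":
        # Assumed meaning: non-parametric LayerNorm, i.e. no learnable affine parameters.
        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
    if norm_type == "rmsnorm":
        return RMSNorm(dim, eps=eps)
    if norm_type == "fused_rmsnorm":
        # In the PR this branch returns TritonFusedRMSNorm (backed by fused_rms_norm.py);
        # stubbed out here because that kernel is not part of this excerpt.
        raise NotImplementedError("fused_rmsnorm needs the Triton kernel added in this PR")
    raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")

A call site then mirrors the ones in model.py above, e.g. create_norm(model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps).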