added unit test for TE layers with comm overlap
Signed-off-by: Alp Dener <[email protected]>
denera committed Aug 7, 2024
1 parent e8a9940 commit 51519eb
Showing 4 changed files with 443 additions and 17 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -166,7 +166,7 @@ def _parse_args(argv=None, namespace=None):
     if opts.atomic:
         if not te.fp8.check_fp8_support():
             assert not opts.fp8, "Atomic GEMM is only supported in FP8."
-            opts.fp8 = True
+        opts.fp8 = True
 
     return opts
 
353 changes: 353 additions & 0 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -0,0 +1,353 @@
#!/usr/bin/python3

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

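"""Compare a Transformer Engine layer running with Userbuffers GEMM+comm overlap
against an identical layer running without overlap.

Must be launched on a single node with `mpirun` or `torchrun`; see _train()."""
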
import os
import sys
import socket
import argparse
import warnings
from functools import partial

import torch
import torch.distributed as dist

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)


def _te_layer_argtype(name):
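    """argparse type converter mapping a layer name to its TE module class."""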
    te_layers = [
        te.Linear,
        te.LayerNormLinear,
        te.LayerNormMLP,
        te.MultiheadAttention,
        te.TransformerLayer,
    ]
    layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers))
    if name.lower() not in layer_map.keys():
        raise argparse.ArgumentTypeError(
            f"Invalid TE layer name! Please choose from: {layer_map.keys()}"
        )
    return layer_map[name.lower()]


def _get_layer_args(config, tp_group, tp_size, reference=False):
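    """Assemble constructor args, kwargs, and the input shape for config.layer_type.

    With reference=True, all Userbuffers overlap options are turned off so the
    layer runs as a plain sequence/tensor-parallel module for comparison.
    """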
    hidden_size = config.num_heads * config.head_dim
    input_shape = [config.seq_length, config.batch_size, hidden_size]
    args = [hidden_size]
    kwargs = {
        "params_dtype": torch.float32,
        "device": "cuda",
        "tp_group": tp_group,
        "tp_size": tp_size,
        "sequence_parallel": True,
    }
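    # Userbuffers overlap flags are enabled only for the test model; the
    # reference model runs the same layer without comm+GEMM overlap.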
    kwargs["ub_overlap_ag"] = not reference

    if config.layer_type is te.Linear:
        input_shape[2] = hidden_size // tp_size
        args.append(hidden_size)
        kwargs["parallel_mode"] = "row"
        kwargs["ub_overlap_rs"] = not reference
        kwargs["ub_name"] = "proj"
    else:
        input_shape[0] = config.seq_length // tp_size
        kwargs["ub_bulk_wgrad"] = not reference
        kwargs["ub_bulk_dgrad"] = not reference
        if config.layer_type is te.LayerNormLinear:
            args.append(3 * hidden_size)
            kwargs["parallel_mode"] = "column"
            kwargs["ub_name"] = "qkv"
        else:
            kwargs["set_parallel_mode"] = True
            kwargs["ub_overlap_rs"] = not reference
            if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]:
                args.append(4 * hidden_size)
                kwargs["seq_length"] = config.seq_length
            if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
                args.append(config.num_heads)
                kwargs["attention_dropout"] = 0.0
                kwargs["fuse_qkv_params"] = True
                if config.layer_type is te.MultiheadAttention:
                    kwargs["input_layernorm"] = True
                else:
                    kwargs["ub_tp_comm_overlap"] = not reference
                    kwargs["hidden_dropout"] = 0.0

    return args, kwargs, input_shape


def _parse_args(argv=None, namespace=None):
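    """Parse command-line options for the overlap test."""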
    parser = argparse.ArgumentParser(
        description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers."
    )
    parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP)
    parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
    parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
    parser.add_argument(
        "-n", "--num-heads", type=int, default=12, help="Number of attention heads."
    )
    parser.add_argument(
        "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head."
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    parser.add_argument(
        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
    )
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--tcp-init",
        action="store_true",
        default=False,
        help="Initialize torch.distributed with TcpStore.",
    )
    parser.add_argument(
        "--bind-to-device",
        action="store_true",
        default=False,
        help="Initialize torch.distributed with `device_id` to bind each rank to a single device.",
    )
    parser.add_argument(
        "--bootstrap-backend",
        type=str.lower,
        default="nccl",
        choices=["gloo", "mpi", "nccl"],
        help="Communications backend for host tensor collectives during Userbuffers bootstrapping.",
    )
    parser.add_argument(
        "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs."
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="Print out additional debug information.",
    )
    args = parser.parse_args(argv, namespace)

    if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
        warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!")
        args.use_cuda_graphs = False

    return args


def _compare_tensors(name, test, ref, rtol, atol):
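    """Check the worst-case elementwise mismatch between test and ref tensors.

    Returns (0, info) when close enough; fails with (1, info) only when both
    the relative and absolute errors exceed their tolerances.
    """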
    # Make sure tensors aren't zero and we don't pass trivially
    if test.count_nonzero() == 0:
        if ref.count_nonzero() == 0:
            warnings.warn(
                f"WARNING: {name} is a zero-tensor for both test and reference models!",
                category=RuntimeWarning,
            )
        else:
            numerics_info = (
                f"NUMERICAL CHECK FAILED: {name} is a zero-tensor but does not match reference!"
            )
            return 1, numerics_info

    diff = torch.abs(test - ref).flatten()
    m = torch.argmax(diff)
    abs_err = diff[m].item()
    rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5)
    numerics_failed = 0
    if rel_err > rtol and abs_err > atol:
        numerics_failed = 1
        numerics_info = (
            "NUMERICAL CHECK FAILED: "
            + f"{name} not close enough at index {m.item()} "
            + f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
            + f"rel. error = {rel_err} (tol = {rtol}) | "
            + f"abs. error = {abs_err} (tol = {atol})"
        )
    else:
        numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
        if rel_err <= rtol:
            numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
                " | " if abs_err <= atol else "."
            )
        if abs_err <= atol:
            numerics_info += f"abs. error = {abs_err} (tol = {atol})"

    return numerics_failed, numerics_info


def _train(opts):
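    """Run one fwd/bwd pass on the overlap and reference models and validate numerics."""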
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        # Execution with `mpirun -np N`
        WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0"))
        WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1"))
        LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
        LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
        opts.tcp_init = True
        opts.bind_to_device = True
        opts.bootstrap_backend = "mpi"
    elif "TORCHELASTIC_RUN_ID" in os.environ:
        WORLD_RANK = int(os.getenv("RANK", "0"))
        WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
        LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    else:
        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
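    # This test runs on a single node, with one rank per GPU.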
    assert LOCAL_SIZE == WORLD_SIZE

    def dist_print(msg, src=None, end="\n", debug=False, error=False):
        if debug and not opts.debug:
            return
        stream = sys.stderr if error else sys.stdout
        if WORLD_RANK == (0 if src is None else src):
            stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
        dist.barrier()

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed(opts.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    if opts.tcp_init:
        MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname()))
        MASTER_PORT = os.getenv("MASTER_PORT", "1234")
        dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}"
    if opts.bind_to_device or opts.bootstrap_backend == "nccl":
        dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

    # Initialize Userbuffers with a sample tensor of shape [seq * batch, hidden]
    te.module.base.initialize_ub(
        [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
        WORLD_SIZE,
        use_fp8=opts.fp8,
        dtype=torch.bfloat16,
        bootstrap_backend=opts.bootstrap_backend,
    )

    # Initialize the Transformer Engine layer with overlap
    args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE)
    with te.fp8_model_init(enabled=opts.fp8_init):
        test_model = opts.layer_type(*args, **kwargs)
    dist_print("Initialized test model...", debug=True)

    # Initialize the reference model and copy all parameters
    ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True)
    with te.fp8_model_init(enabled=opts.fp8_init):
        ref_model = opts.layer_type(*ref_args, **ref_kwargs)
    dist_print("Initialized reference model...", debug=True)
    for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()):
        with torch.no_grad():
            ref_param.copy_(test_param)
        torch.testing.assert_close(test_param, ref_param, rtol=0.0, atol=0.0)
    dist_print("Copied parameters from test model to reference model...", debug=True)

    # Fp8 recipe setup
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

    # Prepare random input tensors
    test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
    test_x.retain_grad()
    ref_x = torch.empty_like(test_x).requires_grad_(True)
    with torch.no_grad():
        ref_x.copy_(test_x)
    torch.testing.assert_close(test_x, ref_x, rtol=0.0, atol=0.0)
    ref_x.retain_grad()

    # Execute fwd/bwd and collect tensors to test
    def run_fwd_bwd(model, x):
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
                y = model(x)
                if isinstance(y, tuple):
                    out, *_ = y
                else:
                    out = y
        loss = out.sum()
        loss.backward()
        return out

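    # Snapshot RNG state so the reference model runs with identical randomness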
    torch_rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}"))
    if opts.use_cuda_graphs:
        test_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(test_graph):
            test_out = run_fwd_bwd(test_model, test_x)
        test_graph.replay()
        del test_graph
    else:
        test_out = run_fwd_bwd(test_model, test_x)
    test_grads = [test_out, test_x.grad]
    names = ["output", "input.grad"]
    for test_name, test_param in test_model.named_parameters():
        if test_param.requires_grad and "layer_norm" not in test_name:
            test_grads.append(test_param.grad)
            names.append(test_name + ".grad")

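    # Restore RNG state and repeat the same fwd/bwd pass with the reference model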
    torch.set_rng_state(torch_rng_state)
    torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}"))
    if opts.use_cuda_graphs:
        ref_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(ref_graph):
            ref_out = run_fwd_bwd(ref_model, ref_x)
        ref_graph.replay()
        del ref_graph
    else:
        ref_out = run_fwd_bwd(ref_model, ref_x)
    ref_grads = [ref_out, ref_x.grad]
    for ref_name, ref_param in ref_model.named_parameters():
        if ref_param.requires_grad and "layer_norm" not in ref_name:
            ref_grads.append(ref_param.grad)

    # Make sure we have the same number of gradients
    numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
    if len(test_grads) != len(ref_grads):
        numerics_failed[0] = 1
        numerics_info = (
            "NUMERICAL CHECK FAILED: Incorrect number of gradients, "
            + f"expected {len(ref_grads)} but got {len(test_grads)}."
        )
        dist_print(numerics_info, src=WORLD_RANK, error=True)
    dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)

    # Now validate accuracy
    if not bool(numerics_failed.item()):
        for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
            rtol = 0.125 if opts.fp8 else 0.025
            atol = 0.0625 if opts.fp8 else 0.00125
            grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
            dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
            numerics_failed[0] = int(grad_failed)
            dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
            if bool(numerics_failed.item()):
                break

    dist_print("Destroying Userbuffers objects...", debug=True)
    te.module.base.destroy_ub()

    dist_print("Destroying all process groups...", debug=True)
    dist.destroy_process_group()
    if opts.debug and WORLD_RANK == 0:
        print("Exiting...\n", end="", flush=True)

    return numerics_failed[0].item()


if __name__ == "__main__":
    sys.exit(_train(_parse_args()))