Update FE to 1.5.2 and miscellaneous fixes (#975)
* update FE to 1.5.2

Signed-off-by: Charlene Yang <[email protected]>

* enable unfused attn for cross attn

Signed-off-by: Charlene Yang <[email protected]>

* unify logging info

Signed-off-by: Charlene Yang <[email protected]>

* omit cudnn 9.1.1 and 9.2.1 due to bugs

Signed-off-by: Charlene Yang <[email protected]>

* set cu_seqlens_padded to cu_seqlens by default

Signed-off-by: Charlene Yang <[email protected]>

* replace variable name with ctx.variable

Signed-off-by: Charlene Yang <[email protected]>

* Revert "enable unfused attn for cross attn"

This reverts commit bc49f14.

Signed-off-by: Charlene Yang <[email protected]>

* restrict cudnn version for fp8 tests

Signed-off-by: Charlene Yang <[email protected]>

* remove mha_fill for FP8

Signed-off-by: Charlene Yang <[email protected]>

* Revert "remove mha_fill for FP8"

This reverts commit 83ffc44114dc6eb3d426d742b6c5a4d34805ec04.

Signed-off-by: Charlene Yang <[email protected]>

* lower cudnn version to >=9.2.1

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by cyanguwa and pre-commit-ci[bot] on Jul 1, 2024
Parent: 7326af9 · Commit: 67b6743
Showing 8 changed files with 66 additions and 32 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 113 files
13 changes: 10 additions & 3 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1270,7 +1270,7 @@ def _rmse(a, b):
return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())


-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@@ -1445,7 +1445,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
return out, param_names, tuple(x.grad for x in params)


-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@@ -1654,7 +1654,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]


-@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(
+    (
+        get_cudnn_version() < (8, 9, 3)
+        if cudnn_frontend_version == 0
+        else get_cudnn_version() < (9, 2, 1)
+    ),
+    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
+)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
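A minimal sketch of the new version gate (not part of the diff; it assumes get_cudnn_version() returns a (major, minor, patch) tuple, as the comparisons above imply): FP8 fused-attention tests on the cuDNN frontend API v1 path now require cuDNN 9.2.1+, while the legacy v0 path keeps the 8.9.3 floor.

# Sketch: the same gate as the skipif decorators above, written as a helper.
def fp8_fused_attn_test_supported(cudnn_version, cudnn_frontend_version):
    required = (8, 9, 3) if cudnn_frontend_version == 0 else (9, 2, 1)
    return cudnn_version >= required

assert fp8_fused_attn_test_supported((9, 3, 0), 1)      # FE v1 path: 9.2.1+ passes
assert not fp8_fused_attn_test_supported((9, 1, 1), 1)  # below the new 9.2.1 floor
assert fp8_fused_attn_test_supported((8, 9, 7), 0)      # legacy FE v0 path keeps 8.9.3+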
2 changes: 2 additions & 0 deletions tests/pytorch/test_sanity.py
@@ -20,6 +20,7 @@
init_method_normal,
scaled_init_method_normal,
is_bf16_compatible,
get_cudnn_version,
)
from transformer_engine.pytorch import (
LayerNormLinear,
@@ -1004,6 +1005,7 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):

@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() != (9, 0), reason="FP8 tests require Hopper.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_sanity_attention_extra_state(model, dtype):
2 changes: 1 addition & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
@@ -85,7 +85,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
(max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
-((cudnn_runtime_version >= 90100) && (max_seqlen_q % 128 == 0) &&
+((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) &&
(max_seqlen_kv % 128 == 0) && (head_dim == 128) &&
((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
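The check above raises the integer cuDNN runtime floor for this FP8 BSHD/SBHD path from 90100 to 90201. A small decoding sketch (assumption: cuDNN 9.x encodes the runtime version as major*10000 + minor*100 + patch, older releases as major*1000 + minor*100 + patch, which matches the constants used here):

# Sketch: mapping the integer constants in the check above to dotted versions.
def decode_cudnn_version(v: int) -> tuple:
    if v >= 90000:  # cuDNN 9.x encoding
        return (v // 10000, (v % 10000) // 100, v % 100)
    return (v // 1000, (v % 1000) // 100, v % 100)  # pre-9.x encoding

assert decode_cudnn_version(90201) == (9, 2, 1)  # new floor for this path
assert decode_cudnn_version(90100) == (9, 1, 0)  # previous floor
assert decode_cudnn_version(8900) == (8, 9, 0)   # legacy T3HD path above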
7 changes: 4 additions & 3 deletions transformer_engine/pytorch/attention.py
@@ -4179,9 +4179,10 @@ def forward(
and cu_seqlens_q is not None
and cu_seqlens_kv is not None
), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!"
-if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
-    cu_seqlens_q_padded = cu_seqlens_q
-    cu_seqlens_kv_padded = cu_seqlens_kv
+
+if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None:
+    cu_seqlens_q_padded = cu_seqlens_q
+    cu_seqlens_kv_padded = cu_seqlens_kv

qkv_dtype = TE_DType[query_layer.dtype]

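A small illustration of the new default (values are hypothetical, not from the PR): when no padded cumulative sequence lengths are supplied, the unpadded offsets are reused, which is exact whenever sequences are packed without padding between them.

import torch

cu_seqlens_q = torch.tensor([0, 128, 320, 512], dtype=torch.int32)
cu_seqlens_q_padded = None  # caller did not provide padded offsets

# Default introduced by this commit: fall back to the unpadded offsets.
if cu_seqlens_q_padded is None:
    cu_seqlens_q_padded = cu_seqlens_q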
24 changes: 16 additions & 8 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -4,6 +4,7 @@

"""GroupedLinear API"""
import os
+import logging
from typing import Union, Optional, Callable, Tuple, List, Dict, Any

import torch
@@ -44,7 +45,16 @@
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

__all__ = ["GroupedLinear"]

@@ -95,6 +105,7 @@ def forward(
is_grad_enabled: bool,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor:
logger = logging.getLogger("GroupedLinear")
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
weights_fp8 = weights_and_biases[num_gemms : 2 * num_gemms]
@@ -149,8 +160,7 @@ def forward(
inputmats = inputmats_no_fp8

if fp8:
-if _NVTE_DEBUG:
-    print("[GroupedLinear]: using FP8 forward")
+logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
@@ -188,8 +198,7 @@ def forward(
# unpad the output
out = torch.cat([o[: m_splits[i]] for i, o in enumerate(out_list)], dim=0)
else:
-if _NVTE_DEBUG:
-    print("[GroupedLinear]: using non-FP8 forward")
+logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weights = [cast_if_needed(w, activation_dtype) for w in weights]
@@ -294,6 +303,7 @@ def forward(

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("GroupedLinear")

with torch.cuda.nvtx.range("_GroupedLinear_backward"):
(
@@ -361,8 +371,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

if ctx.requires_dgrad:
if ctx.fp8:
-if _NVTE_DEBUG:
-    print("[GroupedLinear]: using FP8 backward")
+logger.debug("Running backward in FP8")
dgrad_list = [
torch.empty(
(grad_output_c[i].size(0), weights_fp8[i].size(1)),
@@ -392,8 +401,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
[d[: ctx.m_splits[i]] for i, d in enumerate(dgrad_list)], dim=0
)
else:
-if _NVTE_DEBUG:
-    print("[GroupedLinear]: using non-FP8 backward")
+logger.debug("Running backward in %s", ctx.activation_dtype)

dgrad = torch.empty(
(sum(ctx.m_splits), weights[0].size(1)),
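A usage sketch for the unified logging added here and in linear.py/layernorm_linear.py (the script below is illustrative, not from the PR): the two environment variables must be set before transformer_engine.pytorch is imported, since the log level is chosen at module import time.

import os

# NVTE_DEBUG enables debug mode; NVTE_DEBUG_LEVEL selects verbosity (0/1/2).
os.environ["NVTE_DEBUG"] = "1"
os.environ["NVTE_DEBUG_LEVEL"] = "2"

import transformer_engine.pytorch as te  # noqa: E402

# Subsequent forward/backward calls of te.Linear, te.LayerNormLinear and
# te.GroupedLinear then log messages such as "Running forward in FP8",
# prefixed with the level and module name per the format string above.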
24 changes: 16 additions & 8 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -5,6 +5,7 @@
"""LayerNormLinear API"""
import os
import warnings
+import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
@@ -47,7 +48,16 @@
from ._common import _apply_normalization, _noop_cat
from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

__all__ = ["LayerNormLinear"]

@@ -94,6 +104,7 @@ def forward(
ub_name: str,
fsdp_group: Union[dist_group_type, None],
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
logger = logging.getLogger("LayerNormLinear")
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
assert inp.shape[-1] == in_features, "GEMM not possible"
@@ -190,8 +201,7 @@ def forward(
ln_out = ln_out_total

if fp8:
-if _NVTE_DEBUG:
-    print("[LayerNormLinear]: using FP8 forward")
+logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -247,8 +257,7 @@ def forward(
dtype=activation_dtype,
)
else:
-if _NVTE_DEBUG:
-    print("[LayerNormLinear]: using non-FP8 forward")
+logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
@@ -370,6 +379,7 @@ def forward(
def backward(
ctx, *grad_outputs: Tuple[torch.Tensor, ...]
) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("LayerNormLinear")
if isinstance(grad_outputs[0], Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[
0
@@ -490,8 +500,7 @@ def backward(
ub_obj = None

if ctx.fp8:
-if _NVTE_DEBUG:
-    print("[LayerNormLinear]: using FP8 backward")
+logger.debug("Running backward in FP8")

fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
@@ -535,8 +544,7 @@ def backward(
)
clear_tensor_data(grad_output_c)
else:
-if _NVTE_DEBUG:
-    print("[LayerNormLinear]: using non-FP8 backward")
+logger.debug("Running backward in %s", ctx.activation_dtype)

# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = tex.gemm(
24 changes: 16 additions & 8 deletions transformer_engine/pytorch/module/linear.py
@@ -4,6 +4,7 @@

"""Linear API"""
import os
+import logging
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
@@ -50,7 +51,16 @@
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor

+# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0
+_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0"))
+log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL
+log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}
+logging.basicConfig(
+    format="[%(levelname)-8s | %(name)-19s]: %(message)s",
+    level=log_levels[log_level if log_level in [0, 1, 2] else 2],
+)

__all__ = ["Linear"]

@@ -87,6 +97,7 @@ def forward(
is_first_module_in_mha: bool,
fsdp_group: Union[dist_group_type, None],
) -> torch.Tensor:
logger = logging.getLogger("Linear")
is_input_fp8 = isinstance(inp, Float8Tensor)
if is_input_fp8:
fp8_meta["scaling_fwd"].scale_inv[tex.FP8FwdTensors.GEMM1_INPUT] = inp._scale_inv[0]
@@ -147,8 +158,7 @@ def forward(
else:
inputmat_total = inputmat
if fp8:
-if _NVTE_DEBUG:
-    print("[Linear]: using FP8 forward")
+logger.debug("Running forward in FP8")

bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
bias = cast_if_needed(bias, bias_dtype) if use_bias else bias
@@ -238,8 +248,7 @@ def forward(
dtype=activation_dtype,
)
else:
-if _NVTE_DEBUG:
-    print("[Linear]: using non-FP8 forward")
+logger.debug("Running forward in %s", activation_dtype)

# Cast for native AMP
weight = cast_if_needed(weight, activation_dtype)
@@ -366,6 +375,7 @@ def forward(

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
logger = logging.getLogger("Linear")
if isinstance(grad_output, Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1
@@ -442,8 +452,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

if ctx.requires_dgrad:
if ctx.fp8:
-if _NVTE_DEBUG:
-    print("[Linear]: using FP8 backward")
+logger.debug("Running backward in FP8")

if ctx.is_input_fp8:
out_index, meta_tensor, output_te_dtype, output_dtype = (
@@ -487,8 +496,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dtype=ctx.activation_dtype,
)
else:
-if _NVTE_DEBUG:
-    print("[Linear]: using non-FP8 backward")
+logger.debug("Running backward in %s", ctx.activation_dtype)

dgrad, _, _ = gemm(
weight,
