[PyTorch] Update docs/example and benchmarks/ scripts (NVIDIA#1075)
* update example/benchmark scripts

Signed-off-by: Charlene Yang <[email protected]>

* fix head_dim after MLA

Signed-off-by: Charlene Yang <[email protected]>

* update notebook

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and mgoldfarb-nvidia committed Aug 14, 2024
1 parent cb29a54 commit 302cc22
Showing 5 changed files with 123 additions and 123 deletions.
20 changes: 7 additions & 13 deletions benchmarks/attention/benchmark_attention.py
@@ -11,9 +11,7 @@
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
- _is_flash_attention_supported,
- _is_fused_attention_supported,
- _is_unfused_attention_supported,
+ _get_attention_backends,
_run_dot_product_attention,
)

@@ -29,8 +27,6 @@
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
- # sliding window attention
- swa = False
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
@@ -64,7 +60,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -76,7 +71,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -97,7 +91,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -115,7 +108,6 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
ckpt_attn,
qkv_layout,
workspace_opt,
- swa,
pad_between_seqs,
is_training,
)
@@ -205,13 +197,15 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
- fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
+ available_backends, fused_attn_backends = _get_attention_backends(
config,
- dtype,
+ qkv_dtype=dtype,
+ qkv_layout=qkv_layout,
+ window_size=config.window_size,
+ pad_between_seqs=pad_between_seqs,
)
- fused_attn_supported = fused_attn_supported and not swa
- flash_attn_supported = _is_flash_attention_supported(config)
+ flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

print(
f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
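
Sketch of the new backend-selection pattern in main() above: a single _get_attention_backends query reports support for all three backends, replacing the separate _is_flash_attention_supported / _is_fused_attention_supported / _is_unfused_attention_supported checks. The ModelConfig arguments below are illustrative assumptions, not part of this commit:

import torch
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig, _get_attention_backends

# Hypothetical model settings; the real benchmark pulls these from model_configs.
config = ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias")

# One query reports flash/fused/unfused support and the available fused-attention backends.
available_backends, fused_attn_backends = _get_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
    window_size=config.window_size,
    pad_between_seqs=False,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
print(flash_attn_supported, fused_attn_supported, unfused_attn_supported)
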
16 changes: 11 additions & 5 deletions docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
@@ -6,7 +6,6 @@
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
- from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention

# Initialize RNG state
@@ -22,7 +21,7 @@
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
- _set_cuda_rng_state(_cuda_rng_state)
+ torch.cuda.set_rng_state(_cuda_rng_state)


def _run_dot_product_attention(
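
The change above swaps Transformer Engine's private _set_cuda_rng_state helper for the public torch.cuda.set_rng_state API. A minimal, self-contained sketch of the capture-and-restore pattern the example relies on (plain PyTorch, not Transformer Engine code):

import torch

# Capture the CPU and CUDA RNG states once at startup...
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()

def reset_rng_states() -> None:
    """Revert to the captured RNG states using only public PyTorch APIs."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)

# ...so that repeated runs start from identical random draws.
reset_rng_states()
first = torch.rand(4, device="cuda")
reset_rng_states()
second = torch.rand(4, device="cuda")
assert torch.equal(first, second)
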
@@ -40,7 +39,7 @@ def _run_dot_product_attention(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
inp = torch.randn(
- [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
+ [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim_qk],
dtype=dtype,
device="cuda",
)
@@ -51,7 +50,7 @@ def _run_dot_product_attention(
k.requires_grad = True
v.requires_grad = True
out_grad = torch.randn(
- [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
+ [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim_v],
dtype=dtype,
device="cuda",
)
@@ -80,7 +79,7 @@ def _run_dot_product_attention(

block = DotProductAttention(
config.num_heads,
- config.head_dim,
+ config.head_dim_qk,
num_gqa_groups=config.num_gqa_groups,
qkv_format="bshd",
attention_dropout=config.dropout_p,
@@ -89,6 +88,8 @@ def _run_dot_product_attention(
get_rng_state_tracker=None,
tp_group=None,
layer_number=1,
+ attn_mask_type="no_mask",
+ window_size=(-1, -1),
).to(dtype=dtype, device="cuda")

# Run a forward and backward pass
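
The window_size=(-1, -1) arguments added above make the full-attention behaviour explicit: a negative bound means the window is unbounded on that side, so no sliding-window masking is applied. A toy illustration of the (left, right) convention, written in plain PyTorch as an assumed reading of the semantics rather than Transformer Engine code:

import torch

def sliding_window_mask(seqlen_q: int, seqlen_kv: int, window_size=(-1, -1)) -> torch.Tensor:
    """True = attend, False = masked out. (-1, -1) leaves every position visible."""
    left, right = window_size
    q_idx = torch.arange(seqlen_q).unsqueeze(1)
    kv_idx = torch.arange(seqlen_kv).unsqueeze(0)
    mask = torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool)
    if left >= 0:
        mask &= kv_idx >= q_idx - left    # keys no further than `left` behind the query
    if right >= 0:
        mask &= kv_idx <= q_idx + right   # keys no further than `right` ahead of the query
    return mask

print(sliding_window_mask(4, 4))           # all True: no sliding window
print(sliding_window_mask(4, 4, (1, 0)))   # causal band of width 2
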
@@ -103,6 +104,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # 'arbitrary'
core_attention_bias_type=config.attn_bias_type, # 'no_bias'
core_attention_bias=bias, # None
+ window_size=(-1, -1),
)
out.backward(out_grad)

@@ -116,6 +118,7 @@ def _run_dot_product_attention(
attn_mask_type=config.attn_mask_type, # no_mask
core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
core_attention_bias=bias, # bias
+ window_size=(-1, -1),
)
out.backward(out_grad)

@@ -133,11 +136,14 @@ def _run_dot_product_attention(
config = model_configs["test_bias"]
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")

+ print()
print("Run with arbitrary mask:")
config = model_configs["test_mask"]
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")

torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)

+ print()
print("Test passed!")