[PyTorch] Merge k_channels and v_channels back to kv_channels (NVIDIA#1094)

* merge k_channels and v_channels back to kv_channels and accept a tuple

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix isinstance call

Signed-off-by: Charlene Yang <[email protected]>

* fix MLA tests

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and mgoldfarb-nvidia committed Aug 14, 2024
1 parent 8e2f4a2 commit cb29a54
Showing 3 changed files with 24 additions and 12 deletions.
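
To illustrate the API change, here is a minimal usage sketch of `DotProductAttention` after this commit. The head counts, head sizes, and other arguments below are hypothetical and not taken from the commit; only the int-or-tuple behavior of `kv_channels` reflects the change itself.

import transformer_engine.pytorch as te

# K and V share a head size: kv_channels stays a plain int, as before this commit.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.0,
    qkv_format="sbhd",
)

# K and V have different head sizes (e.g. MLA-style models): pass a (k, v) tuple.
attn_mla = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=(192, 128),
    attention_dropout=0.0,
    qkv_format="sbhd",
)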
2 changes: 1 addition & 1 deletion tests/pytorch/fused_attn/test_fused_attn.py
@@ -902,7 +902,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
# Set up model
block = DotProductAttention(
config.num_heads,
- config.head_dim_qk,
+ (config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
2 changes: 1 addition & 1 deletion tests/pytorch/test_onnx_export.py
@@ -1083,7 +1083,7 @@ def test_export_core_attention(

model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
- k_channels=kv_channels,
+ kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
32 changes: 22 additions & 10 deletions transformer_engine/pytorch/attention.py
@@ -5177,10 +5177,9 @@ class DotProductAttention(TransformerEngineBaseModule):
----------
num_attention_heads : int
number of attention heads in the transformer layer.
- k_channels : int
- number of channels per attention head in key.
- v_channels : Optional[int] = None
- number of channels per attention head in value.
+ kv_channels : Union[int, Tuple[int, int]]
+ the head size in key and value tensors. If the same, :attr:`kv_channels` can be
+ an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
@@ -5242,7 +5241,7 @@ class DotProductAttention(TransformerEngineBaseModule):
For that, please use `get_qkv_layout` to gain the layout information.
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
- `1.0 / math.sqrt(kv_channels)`.
+ `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
Parallelism parameters
----------------------
@@ -5266,8 +5265,7 @@ class DotProductAttention(TransformerEngineBaseModule):
def __init__(
self,
num_attention_heads: int,
- k_channels: int,
- v_channels: Optional[int] = None,
+ kv_channels: Union[int, Tuple[int, int]],
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
qkv_format: str = "sbhd",
@@ -5310,8 +5308,12 @@ def __init__(
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream

- self.hidden_size_per_attention_head = k_channels
- self.v_channels = k_channels if v_channels is None else v_channels
+ self.hidden_size_per_attention_head_k = (
+     kv_channels if isinstance(kv_channels, int) else kv_channels[0]
+ )
+ self.hidden_size_per_attention_head_v = (
+     kv_channels if isinstance(kv_channels, int) else kv_channels[1]
+ )

self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
@@ -5329,7 +5331,9 @@ def __init__(
attention_dropout_ctx = self.rng_states_tracker.fork

if softmax_scale is None:
- softmax_scale = 1.0 / math.sqrt(k_channels)
+ softmax_scale = 1.0 / math.sqrt(
+     kv_channels if isinstance(kv_channels, int) else kv_channels[0]
+ )

self.deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@@ -5628,6 +5632,14 @@ def forward(
assert (
key_layer.shape[:-1] == value_layer.shape[:-1]
), "Keys and values must have the same batch size, sequence length and number of heads!"
+ assert key_layer.shape[-1] == self.hidden_size_per_attention_head_k, (
+     f"Keys have head_dim = {key_layer.shape[-1]}, "
+     f"but expected head_dim = {self.hidden_size_per_attention_head_k}!"
+ )
+ assert value_layer.shape[-1] == self.hidden_size_per_attention_head_v, (
+     f"Values have head_dim = {value_layer.shape[-1]}, "
+     f"but expected head_dim = {self.hidden_size_per_attention_head_v}!"
+ )

if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
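For reference, a self-contained sketch of the head-size resolution and default softmax-scale logic that the attention.py changes above implement. The helper name and the example numbers are made up for illustration; only the int-vs-tuple handling mirrors the diff.

import math
from typing import Tuple, Union

def resolve_head_dims(kv_channels: Union[int, Tuple[int, int]]):
    # A plain int means K and V share a head size; a tuple is (k_head_dim, v_head_dim).
    head_dim_k = kv_channels if isinstance(kv_channels, int) else kv_channels[0]
    head_dim_v = kv_channels if isinstance(kv_channels, int) else kv_channels[1]
    # The default softmax scale follows the K head size, as in the diff above.
    softmax_scale = 1.0 / math.sqrt(head_dim_k)
    return head_dim_k, head_dim_v, softmax_scale

print(resolve_head_dims(64))          # (64, 64, 0.125)
print(resolve_head_dims((192, 128)))  # (192, 128, ~0.0722)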
