Skip to content

Commit

Permalink
Add THD + GQA support (#1260)
Browse files Browse the repository at this point in the history
Add THD + GQA support for cuDNN >= 9.6

Signed-off-by: Reese Wang <[email protected]>
  • Loading branch information
zlsh80826 authored Oct 22, 2024
1 parent 35f7d26 commit d9b4bfb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 7 additions & 4 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,14 @@ def _check_configs(self):
]:
pytest.skip("THD format requires padding masks.")

if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD:
if self.num_heads_q != self.num_heads_kv:
pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.")
qkv_format = get_qkv_format(self.qkv_layout)
if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD:
if self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.")
pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv")

if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD:
if self.num_heads_q != self.num_heads_kv:
pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv")

if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None:
pytest.skip(
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
// qkv format
((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
(sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
qkv_format == NVTE_QKV_Format::NVTE_THD) ||
(qkv_format == NVTE_QKV_Format::NVTE_BSHD)) &&
((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
(cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 90600))) &&
// sliding window
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
(window_size_right == -1 || window_size_right == 0)) ||
Expand Down

0 comments on commit d9b4bfb

Please sign in to comment.