diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index d4f92e940d..a48a38e6a9 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -305,11 +305,14 @@ def _check_configs(self): ]: pytest.skip("THD format requires padding masks.") - if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD: - if self.num_heads_q != self.num_heads_kv: - pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.") + qkv_format = get_qkv_format(self.qkv_layout) + if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: if self.max_seqlen_q != self.max_seqlen_kv: - pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.") + pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") + + if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: + if self.num_heads_q != self.num_heads_kv: + pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None: pytest.skip( diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ff8b66e82f..afa8a9c58d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -181,10 +181,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format - ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) || - (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups && - qkv_format == NVTE_QKV_Format::NVTE_THD) || - (qkv_format == NVTE_QKV_Format::NVTE_BSHD)) && + ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && + 
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || + (cudnn_runtime_version >= 90600)))) && // sliding window ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||