mtgpu: disable flash attention on qy1 (MTT S80)
Signed-off-by: Xiaodong Ye <[email protected]>
yeahdongcn committed Sep 18, 2024
1 parent 30724dc commit abbc1f2
Showing 3 changed files with 16 additions and 1 deletion.
common/arg.cpp: 5 additions & 0 deletions
@@ -703,7 +703,12 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         {"-fa", "--flash-attn"},
         format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
         [](gpt_params & params) {
+#ifdef FLASH_ATTN_AVAILABLE
             params.flash_attn = true;
+#else
+            GGML_UNUSED(params);
+            fprintf(stderr, "warning: flash attention is not supported\n");
+#endif // FLASH_ATTN_AVAILABLE
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
     add_opt(llama_arg(
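
With this change, -fa only enables flash attention in builds where FLASH_ATTN_AVAILABLE is defined; otherwise the handler leaves params.flash_attn at its default and prints a warning to stderr. A minimal standalone sketch of the same compile-time gating pattern (hypothetical handle_flash_attn_flag helper and simplified gpt_params, not the actual llama.cpp sources):

// Minimal sketch of the compile-time gating pattern above, with a hypothetical
// helper and a simplified gpt_params; build with -DFLASH_ATTN_AVAILABLE to
// take the enabled path.
#include <cstdio>

struct gpt_params {
    bool flash_attn = false;
};

static void handle_flash_attn_flag(gpt_params & params) {
#ifdef FLASH_ATTN_AVAILABLE
    params.flash_attn = true;  // feature compiled in: honor the flag
#else
    (void) params;             // stand-in for GGML_UNUSED(params)
    std::fprintf(stderr, "warning: flash attention is not supported\n");
#endif // FLASH_ATTN_AVAILABLE
}

int main() {
    gpt_params params;
    handle_flash_attn_flag(params);  // mirrors the -fa handler above
    std::printf("flash_attn: %s\n", params.flash_attn ? "enabled" : "disabled");
    return 0;
}
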
ggml/src/ggml-cuda/common.cuh: 6 additions & 0 deletions
@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220

 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
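
CC_QY1 (210) and CC_QY2 (220) extend the existing CC_* compute-capability constants to the Moore Threads qy1 (MTT S80) and qy2 generations, and the new preprocessor block leaves FLASH_ATTN_AVAILABLE undefined when MUSA device code is compiled for an architecture at or below CC_QY1. As an illustration only, the same cutoff could be expressed as a host-side helper in the style of fast_fp16_available (hypothetical flash_attn_available_musa function, not part of this commit):

// Hypothetical helper (not in the commit), shown standalone with the constants
// from the hunk above: flash attention is considered available only for MUSA
// compute capabilities newer than qy1 (MTT S80).
#define CC_QY1 210
#define CC_QY2 220

static constexpr bool flash_attn_available_musa(const int cc) {
    return cc > CC_QY1;   // same cutoff as the __MUSA_ARCH__ <= CC_QY1 guard
}

static_assert(!flash_attn_available_musa(CC_QY1), "qy1 (MTT S80): flash attention disabled");
static_assert( flash_attn_available_musa(CC_QY2), "qy2: flash attention available");
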
ggml/src/ggml-cuda/fattn-tile-f32.cu: 5 additions & 1 deletion
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
     }

-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
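
The new guard compiles the kernel body down to a NO_DEVICE_CODE stub on qy1, while the surrounding context lines split the 1-D block index into a Q/QKV column group and a slot within that group. A small worked example of that indexing, using arbitrary illustrative values rather than any real launch configuration:

// Worked example of the ic0/ip indexing from the context lines above;
// parallel_blocks, ncols, and block_idx are arbitrary illustrative values.
#include <cstdio>

int main() {
    const int parallel_blocks = 4;   // blocks running in parallel for the same column group
    const int ncols           = 8;   // Q/QKV columns handled per column group
    const int block_idx       = 10;  // stand-in for blockIdx.x

    const int ic0 = (block_idx / parallel_blocks) * ncols; // first column to work on: (10/4)*8 = 16
    const int ip  = block_idx % parallel_blocks;           // slot within the group: 10%4 = 2

    std::printf("ic0 = %d, ip = %d\n", ic0, ip);
    return 0;
}
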
