From abbc1f28ca617aa73664788cfd5b2995e27ec8d3 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye
Date: Wed, 18 Sep 2024 17:53:55 +0800
Subject: [PATCH] mtgpu: disable flash attention on qy1 (MTT S80)

Signed-off-by: Xiaodong Ye
---
 common/arg.cpp                       | 5 +++++
 ggml/src/ggml-cuda/common.cuh        | 6 ++++++
 ggml/src/ggml-cuda/fattn-tile-f32.cu | 6 +++++-
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 922391069d32aa..0c2a80f0fd7379 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -703,7 +703,12 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
         {"-fa", "--flash-attn"},
         format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
         [](gpt_params & params) {
+#ifdef FLASH_ATTN_AVAILABLE
             params.flash_attn = true;
+#else
+            GGML_UNUSED(params);
+            fprintf(stderr, "warning: flash attention is not supported\n");
+#endif // FLASH_ATTN_AVAILABLE
         }
     ).set_env("LLAMA_ARG_FLASH_ATTN"));
     add_opt(llama_arg(
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index eb39b6d23a6b3f..6b437d3789f16b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -50,6 +50,8 @@
 #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
 #define CC_RDNA3 (CC_OFFSET_AMD + 1100)
+#define CC_QY1 210
+#define CC_QY2 220
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -134,6 +136,10 @@ typedef float2 dfloat2;
 #define INT8_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= CC_PASCAL && cc != 610;
 }
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index 827437ca0ad1ff..f402195ce0b774 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
        return;
    }
 
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
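
Note (not part of the patch): the FLASH_ATTN_AVAILABLE guard added to common.cuh compiles the flash attention path out whenever a MUSA build targets an architecture at or below CC_QY1 (the MTT S80). The standalone sketch below shows how that preprocessor gate resolves for such a target; GGML_USE_MUSA and __MUSA_ARCH__ are hard-coded here purely for illustration, whereas in a real build the MUSA toolchain defines them.

    // Standalone illustration of the FLASH_ATTN_AVAILABLE gate; the two
    // macros below stand in for what the MUSA compiler would normally define.
    #include <cstdio>

    #define GGML_USE_MUSA
    #define __MUSA_ARCH__ 210   // hypothetical: compiling for qy1 (MTT S80)

    #define CC_QY1 210
    #define CC_QY2 220

    // Same condition as the patch: only MUSA builds at or below qy1 lose flash attention.
    #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
    #define FLASH_ATTN_AVAILABLE
    #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)

    int main() {
    #ifdef FLASH_ATTN_AVAILABLE
        std::printf("flash attention kernels compiled in\n");
    #else
        std::printf("flash attention not supported on this target\n");
    #endif // FLASH_ATTN_AVAILABLE
        return 0;
    }

On a non-MUSA build, or with __MUSA_ARCH__ above CC_QY1, the condition fails, FLASH_ATTN_AVAILABLE stays defined, and the -fa / --flash-attn flag enables flash attention as before; otherwise the flag only prints the new warning.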