From cd93a28cb1446319af5e2f4b416174c3a8e43546 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Thu, 23 May 2024 00:31:20 +0200
Subject: [PATCH] CUDA: fix FA out-of-bounds reads (#7479)

---
 ggml-cuda/fattn-tile-f16.cu | 2 +-
 ggml-cuda/fattn-tile-f32.cu | 2 +-
 ggml-cuda/fattn-vec-f16.cu  | 6 +++---
 ggml-cuda/fattn-vec-f32.cu  | 6 +++---
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 586d469c049d1..cdb5eaff79535 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
         }
     }
diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu
index b6ef8eb48d992..5a3de2918c7a3 100644
--- a/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml-cuda/fattn-tile-f32.cu
@@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
-            float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
+            float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
             Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
             Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
         }
diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu
index 7352dcabf6291..808e8f36246a7 100644
--- a/ggml-cuda/fattn-vec-f16.cu
+++ b/ggml-cuda/fattn-vec-f16.cu
@@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
         }
     }
@@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ic0 + j_VKQ >= ne01) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
             break;
         }
 
@@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu
index 11476a6c0fbbc..b4652301b87e0 100644
--- a/ggml-cuda/fattn-vec-f32.cu
+++ b/ggml-cuda/fattn-vec-f32.cu
@@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
 
-            Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i];
+            Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
             Q_h2[j][i0/WARP_SIZE].x *= scale;
             Q_h2[j][i0/WARP_SIZE].y *= scale;
         }
@@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
     for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ic0 + j_VKQ >= ne01) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
             break;
         }
 
@@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }
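All of the hunks above share one pattern: each block handles ncols consecutive Q rows starting at row ic0, so when ne01 is not a multiple of ncols the last block extends past the end of the tensor. The fix replaces loads from those out-of-range rows with zeros and tightens the guards on the stores; the vec kernels bypass the check when ncols <= 2, presumably because those variants are only launched when ne01 is divisible by ncols. The standalone sketch below shows the same guarded-load idea in isolation; the kernel and all names in it (guarded_q_load, D, NCOLS) are hypothetical and are not the ggml-cuda code.

    // Minimal sketch of the bounds-guarded load, assuming D floats per row
    // and NCOLS rows per block; not the actual ggml-cuda kernels.
    #include <cuda_runtime.h>
    #include <cstdio>

    constexpr int WARP_SIZE = 32;

    template <int D, int NCOLS>
    __global__ void guarded_q_load(const float2 *Q_f2, float2 *out, const int ne01) {
        const int ic0 = blockIdx.x*NCOLS; // first row this block covers
    #pragma unroll
        for (int j = 0; j < NCOLS; ++j) {
    #pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;
                // The guard: rows past ne01 load zeros instead of reading
                // out of bounds; the zeros are never written back.
                const float2 tmp = ic0 + j < ne01 ?
                    Q_f2[(ic0 + j)*(D/2) + i] : make_float2(0.0f, 0.0f);
                if (ic0 + j < ne01) {
                    out[(ic0 + j)*(D/2) + i] = tmp;
                }
            }
        }
    }

    int main() {
        constexpr int D = 64, NCOLS = 8;
        const int ne01 = 10; // not a multiple of NCOLS -> last block is partial
        float2 *q, *o;
        cudaMalloc((void **) &q, ne01*(D/2)*sizeof(float2));
        cudaMalloc((void **) &o, ne01*(D/2)*sizeof(float2));
        const int blocks = (ne01 + NCOLS - 1)/NCOLS;
        guarded_q_load<D, NCOLS><<<blocks, WARP_SIZE>>>(q, o, ne01);
        printf("status: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));
        cudaFree(q); cudaFree(o);
        return 0;
    }

With ne01 = 10 and NCOLS = 8 the second block is assigned rows 8..15; without the ternary it would read six rows past the end of Q_f2, which is exactly the class of out-of-bounds access this patch removes.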