Skip to content

Commit

Permalink
CUDA: fix FA out-of-bounds reads (ggerganov#7479)
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler authored and teleprint-me committed May 23, 2024
1 parent e6f0ab5 commit 94672bc
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion ggml-cuda/fattn-tile-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
Expand Down
2 changes: 1 addition & 1 deletion ggml-cuda/fattn-tile-f32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(

#pragma unroll
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
}
Expand Down
6 changes: 3 additions & 3 deletions ggml-cuda/fattn-vec-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
Expand Down Expand Up @@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(

#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ic0 + j_VKQ >= ne01) {
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
break;
}

Expand All @@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}

if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
Expand Down
6 changes: 3 additions & 3 deletions ggml-cuda/fattn-vec-f32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;

Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i];
Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
Q_h2[j][i0/WARP_SIZE].x *= scale;
Q_h2[j][i0/WARP_SIZE].y *= scale;
}
Expand Down Expand Up @@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(

#pragma unroll
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
if (ic0 + j_VKQ >= ne01) {
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
break;
}

Expand All @@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}

if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
}
Expand Down

0 comments on commit 94672bc

Please sign in to comment.