Improve the performance of FlashAttentionV2
zhongkaifu committed Jul 21, 2024
1 parent 07f3a33 commit 5f3c8c9
Showing 1 changed file with 38 additions and 45 deletions.
83 changes: 38 additions & 45 deletions TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
@@ -87,7 +87,7 @@ void flash_attention_2_forward_kernel(
for (int x = 0; x < d; x++)
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
S[y] = sum;
if (sum > row_m)
row_m = sum;
@@ -104,8 +104,8 @@ void flash_attention_2_forward_kernel(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - new_row_m);
row_l += S[(Bc * tx) + y];
S[y] = __expf(S[y] - new_row_m);
row_l += S[y];
}
// l_i^j = (exp(m_i^j-1 - m_i^j) * l_i^j-1) + row_sum(P_i^j)
@@ -120,7 +120,7 @@ void flash_attention_2_forward_kernel(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
pv += S[y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = \
row_m_exp * O[qkv_offset + (tile_size * i) + (tx * d) + x] + pv;
@@ -174,14 +174,14 @@ void flash_attention_2_backward_kernel(
float* Vj = &sram[col_tile_size];
float* Qi = &sram[col_tile_size * 2];
float* Oi = &sram[col_tile_size * 2 + row_tile_size];
float* dOi = &sram[col_tile_size * 2 + row_tile_size * 2];
//float* Oi = &sram[col_tile_size * 2 + row_tile_size];
float* dOi = &sram[col_tile_size * 2 + row_tile_size];
// We also use S for P. Likewise, we use dS for dP.
// We can reuse the same memory because we don't need S and P at the same time.
// We also don't need dS and dP at the same time.
float* S = &sram[col_tile_size * 2 + row_tile_size * 3];
float* dS = &sram[col_tile_size * 2 + row_tile_size * 3 + Bc * Br];
float* S = &sram[col_tile_size * 2 + row_tile_size * 2];
//float* dS = &sram[col_tile_size * 2 + row_tile_size * 2 + Bc * Br];
for (int j = 0; j < Tc; j++) {
@@ -198,9 +198,8 @@ void flash_attention_2_backward_kernel(
float Di = 0;
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (row_tile_size * i) + (tx * d) + x];
Oi[(tx * d) + x] = O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
dOi[(tx * d) + x] = dO[qkv_offset + (row_tile_size * i) + (tx * d) + x];
Di += dOi[(tx * d) + x] * Oi[(tx * d) + x];
Di += dOi[(tx * d) + x] * O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
}
float l_curr = L[lm_offset + (Br * i) + tx];
@@ -238,26 +237,24 @@ void flash_attention_2_backward_kernel(
// dPij <- dOi * Vj^T
// dPij[tx][y] = Sum_{x = 0}^{d-1} dOi[tx][x] * Vj[y][x]
// dSij <- Pij * (dPij - Di)
// dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += dOi[(tx * d) + x] * Vj[(y * d) + x];
}
dS[(Bc * tx) + y] = sum;
}
// dSij <- Pij * (dPij - Di)
// dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
for (int y = 0; y < Bc; ++y) {
dS[(Bc * tx) + y] = S[(Bc * tx) + y] * (dS[(Bc * tx) + y] - Di);
S[(Bc * tx) + y] = S[(Bc * tx) + y] * (sum - Di);
}
// dQi <- dQi + softmax_scale * dSijKj
// dQ[tx][x] = dQ[tx][x] + softmax_scale * Sum_{y = 0}^{Bc-1} dSij[tx][y] * Kj[y][x]
for (int x = 0; x < d; x++) {
float sum = 0;
for (int y = 0; y < Bc; y++) {
sum += dS[(Bc * tx) + y] * Kj[(y * d) + x];
sum += S[(Bc * tx) + y] * Kj[(y * d) + x];
}
sum *= softmax_scale;
atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + (tx * d) + x], sum);
@@ -268,7 +265,7 @@
for (int x = 0; x < d; x++) {
float sum = 0;
for (int y = 0; y < Br; y++) {
sum += dS[(Bc * y) + tx] * Qi[(y * d) + x];
sum += S[(Bc * y) + tx] * Qi[(y * d) + x];
}
sum *= softmax_scale;
atomicAdd(&dK[qkv_offset + (row_tile_size * j) + (tx * d) + x], sum);
@@ -1420,7 +1417,7 @@ void flash_attention_2_forward_kernelHalf(
for (int x = 0; x < d; x++)
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
sum *= softmax_scale;
S[(Bc * tx) + y] = sum;
S[y] = sum;
if (sum > row_m)
row_m = sum;
@@ -1437,8 +1434,8 @@ void flash_attention_2_forward_kernelHalf(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - new_row_m);
row_l += S[(Bc * tx) + y];
S[y] = __expf(S[y] - new_row_m);
row_l += S[y];
}
// l_i^j = (exp(m_i^j-1 - m_i^j) * l_i^j-1) + row_sum(P_i^j)
@@ -1453,7 +1450,7 @@ void flash_attention_2_forward_kernelHalf(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
pv += S[y] * Vj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = \
__float2half(row_m_exp * __half2float(O[qkv_offset + (tile_size * i) + (tx * d) + x]) + pv);
@@ -1510,14 +1507,14 @@ void flash_attention_2_backward_kernelHalf(
float* Vj = &sram[col_tile_size];
float* Qi = &sram[col_tile_size * 2];
float* Oi = &sram[col_tile_size * 2 + row_tile_size];
float* dOi = &sram[col_tile_size * 2 + row_tile_size * 2];
//float* Oi = &sram[col_tile_size * 2 + row_tile_size];
float* dOi = &sram[col_tile_size * 2 + row_tile_size];
// We also use S for P. Likewise, we use dS for dP.
// We can reuse the same memory because we don't need S and P at the same time.
// We also don't need dS and dP at the same time.
float* S = &sram[col_tile_size * 2 + row_tile_size * 3];
float* dS = &sram[col_tile_size * 2 + row_tile_size * 3 + Bc * Br];
float* S = &sram[col_tile_size * 2 + row_tile_size * 2];
//float* dS = &sram[col_tile_size * 2 + row_tile_size * 2 + Bc * Br];
for (int j = 0; j < Tc; j++) {
@@ -1534,9 +1531,8 @@ void flash_attention_2_backward_kernelHalf(
float Di = 0;
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = __half2float(Q[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
Oi[(tx * d) + x] = __half2float(O[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
dOi[(tx * d) + x] = dO[qkv_offset + (row_tile_size * i) + (tx * d) + x];
Di += dOi[(tx * d) + x] * Oi[(tx * d) + x];
dOi[(tx * d) + x] = __half2float(dO[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
Di += dOi[(tx * d) + x] * __half2float(O[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
}
float l_curr = L[lm_offset + (Br * i) + tx];
@@ -1574,26 +1570,23 @@ void flash_attention_2_backward_kernelHalf(
// dPij <- dOi * Vj^T
// dPij[tx][y] = Sum_{x = 0}^{d-1} dOi[tx][x] * Vj[y][x]
// dSij <- Pij * (dPij - Di)
// dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
sum += dOi[(tx * d) + x] * Vj[(y * d) + x];
}
dS[(Bc * tx) + y] = sum;
}
// dSij <- Pij * (dPij - Di)
// dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
for (int y = 0; y < Bc; ++y) {
dS[(Bc * tx) + y] = S[(Bc * tx) + y] * (dS[(Bc * tx) + y] - Di);
S[(Bc * tx) + y] = S[(Bc * tx) + y] * (sum - Di);
}
// dQi <- dQi + softmax_scale * dSijKj
// dQ[tx][x] = dQ[tx][x] + softmax_scale * Sum_{y = 0}^{Bc-1} dSij[tx][y] * Kj[y][x]
for (int x = 0; x < d; x++) {
float sum = 0;
for (int y = 0; y < Bc; y++) {
sum += dS[(Bc * tx) + y] * Kj[(y * d) + x];
sum += S[(Bc * tx) + y] * Kj[(y * d) + x];
}
sum *= softmax_scale;
atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + (tx * d) + x], __float2half(sum));
@@ -1604,7 +1597,7 @@ void flash_attention_2_backward_kernelHalf(
for (int x = 0; x < d; x++) {
float sum = 0;
for (int y = 0; y < Br; y++) {
sum += dS[(Bc * y) + tx] * Qi[(y * d) + x];
sum += S[(Bc * y) + tx] * Qi[(y * d) + x];
}
sum *= softmax_scale;
atomicAdd(&dK[qkv_offset + (row_tile_size * j) + (tx * d) + x], __float2half(sum));
@@ -2519,7 +2512,7 @@ private void FlashAttention(TSCudaContext context, Tensor Q, Tensor K, Tensor V,
int N = (int)Q.Sizes[2];
int d = (int)Q.Sizes[3];

int Br = 32;
int Br = Math.Min(N, 512);
while (Br > 1)
{
if (N % Br == 0)
@@ -2541,7 +2534,7 @@ private void FlashAttention(TSCudaContext context, Tensor Q, Tensor K, Tensor V,
int sram_size =
(2 * col_tile_size * sizeof(float)) // SRAM size for Kj, Vj
+ (row_tile_size * sizeof(float)) // SRAM size for Qi
+ (Bc * Br * sizeof(float)); // SRAM size for S
+ (Bc * sizeof(float)); // SRAM size for S

dim3 grid = new dim3(B, nh);
dim3 block = new dim3(Br);
@@ -2591,7 +2584,7 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso
int N = (int)Q.Sizes[2];
int d = (int)Q.Sizes[3];

int Br = 32;
int Br = Math.Min(N, 128);
while (Br > 1)
{
if (N % Br == 0)
@@ -2609,11 +2602,11 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso

// Calculate SRAM size needed per block
int col_tile_size = Bc * d; // size of dKj, dVj
int row_tile_size = Br * d; // size of Qi, Oi, dOi
int row_tile_size = Br * d; // size of Qi, dOi
int sram_size =
(2 * col_tile_size * sizeof(float)) // SRAM size for dKj, dVj
+ (3 * row_tile_size * sizeof(float)) // SRAM size for Qi, Oi, dOi
+ (2 * Br * Bc * sizeof(float)); // SRAM size for S, dS
+ (2 * row_tile_size * sizeof(float)) // SRAM size for Qi, dOi
+ (Br * Bc * sizeof(float)); // SRAM size for S

dim3 grid = new dim3(B, nh);
dim3 block = new dim3(Br);
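The net effect of the forward-pass hunks is a smaller score buffer: the host-side sram_size now reserves Bc floats for S instead of Bc * Br, matching the kernel's switch from S[(Bc * tx) + y] to S[y], and the Br search now starts at Math.Min(N, 512) rather than a fixed 32. The C# sketch below is a minimal illustration of that budget change, not code from the commit; the class name ForwardSramBudget and the tile sizes Br = Bc = 32, d = 64 are assumptions chosen only to make the arithmetic concrete.

using System;

class ForwardSramBudget
{
    static void Main()
    {
        int Br = 32, Bc = 32, d = 64;          // illustrative tile sizes, not taken from the commit
        int colTileSize = Bc * d;              // floats for one of Kj, Vj
        int rowTileSize = Br * d;              // floats for Qi

        // Mirrors the old and new sram_size expressions in FlashAttention().
        int before = (2 * colTileSize + rowTileSize + Bc * Br) * sizeof(float); // Br x Bc tile for S
        int after  = (2 * colTileSize + rowTileSize + Bc) * sizeof(float);      // Bc-float buffer for S

        Console.WriteLine($"forward SRAM per block, before: {before} bytes"); // 28672
        Console.WriteLine($"forward SRAM per block, after:  {after} bytes");  // 24704
    }
}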

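The backward-pass hunks follow the same pattern: Oi is no longer staged in shared memory (Di is accumulated directly from O in global memory), dOi takes Oi's former slot, and dS is folded into S, which is overwritten in place with dSij. The backward path also starts its Br search at Math.Min(N, 128) instead of 32. Below is a sketch of the budget implied by FlashAttentionGrad()'s old and new sram_size formulas, under the same assumed tile sizes; BackwardSramBudget and the sizes are illustrative, not from the commit.

using System;

class BackwardSramBudget
{
    static void Main()
    {
        int Br = 32, Bc = 32, d = 64;          // illustrative tile sizes, not taken from the commit
        int colTileSize = Bc * d;              // floats for one column tile
        int rowTileSize = Br * d;              // floats for one row tile (Qi or dOi)

        // Old formula: Qi, Oi, dOi plus separate S and dS tiles.
        int before = (2 * colTileSize + 3 * rowTileSize + 2 * Br * Bc) * sizeof(float);
        // New formula: Qi, dOi plus a single S tile reused for dS.
        int after  = (2 * colTileSize + 2 * rowTileSize + Br * Bc) * sizeof(float);

        Console.WriteLine($"backward SRAM per block, before: {before} bytes"); // 49152
        Console.WriteLine($"backward SRAM per block, after:  {after} bytes");  // 36864
    }
}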