Improve FlashAttention V2 performance
zhongkaifu committed Jul 27, 2024
1 parent 5f3c8c9 commit c0ac3d0
Showing 1 changed file with 70 additions and 42 deletions.
112 changes: 70 additions & 42 deletions TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
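At a glance, the commit restructures both FlashAttention-2 kernels the same way: the per-block loop over tiles becomes a third grid dimension read from blockIdx.z, the separate Kj and Vj shared-memory tiles are merged into a single KVj buffer that is reloaded between the score and output phases, and the score buffer S grows from Bc to Bc * Br floats so each thread indexes its own row via S[(Bc * tx) + y]. A minimal CUDA sketch of the loop-to-grid change follows; the kernel name, the placeholder body and the simplified guard are assumptions for illustration, not the project's real code.

// tile_loop_sketch: the outer query-tile loop replaced by blockIdx.z (illustrative only).
__global__ void tile_loop_sketch(const float* Q, float* O, int N, int d, int Tr, int Br)
{
    int tx = threadIdx.x;
    // Before: for (int i = 0; i < Tr; ++i) { ... }   // one block walked every query tile
    int i = blockIdx.z;                               // after: one block per query tile
    if (i >= Tr || i * Br + tx >= N)
        return;                                       // out-of-range tile or row: nothing to do
    // ... load the Br x d tile Qi, loop over key tiles j, update O ...
    O[(i * Br + tx) * d] = Q[(i * Br + tx) * d];      // placeholder body so the sketch compiles
}
// Host side, roughly: dim3 grid(B, nh, Tr); dim3 block(Br);
//                     tile_loop_sketch<<<grid, block, sram_size>>>(Q, O, N, d, Tr, Br);

Running the Tr query tiles as independent blocks lets them execute concurrently instead of serially inside one block, which is presumably the main source of the speedup named in the commit title.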
@@ -42,6 +42,7 @@ void flash_attention_2_forward_kernel(
const float INFINITY = 9999999999.9f;
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
int bz = blockIdx.z; // Tr index
// Offset into Q,K,V,O - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
@@ -51,13 +52,16 @@ void flash_attention_2_forward_kernel(
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];
float* KVj = &sram[tile_size];
// float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 2];
for (int i = q_start_offset; i < Tr; ++i) {
int i = bz;
if (i >= q_start_offset && i < Tr)
{
if (i * Br + tx >= N)
break; // break if we are done with the sequence
return; // break if we are done with the sequence
// Load Qi from HBM to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
@@ -71,8 +75,8 @@ void flash_attention_2_forward_kernel(
__syncthreads();
// Load Kj, Vj from HBM to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
KVj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
// Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads();
// S_i^j = softmax_scale * QiKj^T
@@ -85,9 +89,9 @@ void flash_attention_2_forward_kernel(
break;
float sum = 0;
for (int x = 0; x < d; x++)
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
sum += Qi[(tx * d) + x] * KVj[(y * d) + x];
sum *= softmax_scale;
S[y] = sum;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
@@ -104,14 +108,21 @@ void flash_attention_2_forward_kernel(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
S[y] = __expf(S[y] - new_row_m);
row_l += S[y];
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - new_row_m);
row_l += S[(Bc * tx) + y];
}
// l_i^j = (exp(m_i^j-1 - m_i^j) * l_i^j-1) + row_sum(P_i^j)
float row_m_exp = __expf(row_m_prev - new_row_m);
float new_row_l = (row_m_exp * row_l_prev) + row_l;
__syncthreads();
for (int x = 0; x < d; x++) {
// Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
KVj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
}
__syncthreads();
// O_i^j = diag(exp(m_i^j-1 - m_i^j))^-1 * O_i^j-1 + P_i^jVj
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
@@ -120,7 +131,7 @@ void flash_attention_2_forward_kernel(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
pv += S[y] * Vj[(y * d) + x];
pv += S[(Bc * tx) + y] * KVj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = \
row_m_exp * O[qkv_offset + (tile_size * i) + (tx * d) + x] + pv;
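The part of the forward kernel worth pausing on is the barrier discipline around KVj: the buffer holds the K tile while the scores S are formed, then every thread synchronizes before the same storage is overwritten with the V tile for the output accumulation. Below is a stripped-down, self-contained sketch of that pattern; the kernel name is made up, it covers a single Bc x d tile, and it drops the online softmax, causal mask and rescaling that the real kernel performs.

// kv_reuse_sketch: one shared Bc x d buffer serves first as Kj, then as Vj.
// Launch with <<<1, Bc, (2 * Bc * d + Bc * Bc) * sizeof(float)>>> on Bc x d inputs.
__global__ void kv_reuse_sketch(const float* Q, const float* K, const float* V,
                                float* O, int d, int Bc)
{
    extern __shared__ float sram[];
    int tx = threadIdx.x;
    int tile_size = Bc * d;
    float* Qi  = sram;                      // Bc x d query tile
    float* KVj = &sram[tile_size];          // Bc x d tile: first K, then V
    float* S   = &sram[2 * tile_size];      // Bc x Bc scores, row (Bc * tx) per thread

    for (int x = 0; x < d; x++) {
        Qi[tx * d + x]  = Q[tx * d + x];
        KVj[tx * d + x] = K[tx * d + x];    // phase 1: KVj holds the K tile
    }
    __syncthreads();

    for (int y = 0; y < Bc; y++) {          // S = Qi * Kj^T
        float sum = 0.0f;
        for (int x = 0; x < d; x++)
            sum += Qi[tx * d + x] * KVj[y * d + x];
        S[(Bc * tx) + y] = sum;
    }
    __syncthreads();                        // every thread is done reading the K tile

    for (int x = 0; x < d; x++)
        KVj[tx * d + x] = V[tx * d + x];    // phase 2: overwrite the buffer with the V tile
    __syncthreads();

    for (int x = 0; x < d; x++) {           // O = S * Vj
        float pv = 0.0f;
        for (int y = 0; y < Bc; y++)
            pv += S[(Bc * tx) + y] * KVj[y * d + x];
        O[tx * d + x] = pv;
    }
}

The barrier before the V load is what makes the reuse safe: no thread may still be reading Kj through KVj when the buffer is rewritten, and the barrier after the load ensures the V tile is complete before any thread consumes it. The payoff is one Bc x d tile of shared memory saved per block.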
@@ -161,6 +172,7 @@ void flash_attention_2_backward_kernel(
const float INFINITY = 9999999999.9f;
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
int bz = blockIdx.z; // Tc index;
// Offset into Q,K,V,O - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
@@ -183,7 +195,11 @@ void flash_attention_2_backward_kernel(
float* S = &sram[col_tile_size * 2 + row_tile_size * 2];
//float* dS = &sram[col_tile_size * 2 + row_tile_size * 2 + Bc * Br];
for (int j = 0; j < Tc; j++) {
// for (int j = 0; j < Tc; j++) {
int j = bz;
if (j < Tc) {
// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
@@ -198,6 +214,7 @@ void flash_attention_2_backward_kernel(
float Di = 0;
for (int x = 0; x < d; x++) {
Qi[(tx * d) + x] = Q[qkv_offset + (row_tile_size * i) + (tx * d) + x];
//Oi[(tx * d) + x] = O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
dOi[(tx * d) + x] = dO[qkv_offset + (row_tile_size * i) + (tx * d) + x];
Di += dOi[(tx * d) + x] * O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
}
Expand Down Expand Up @@ -240,7 +257,6 @@ void flash_attention_2_backward_kernel(
// dSij <- Pij * (dPij - Di)
// dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
for (int y = 0; y < Bc; y++) {
float sum = 0;
for (int x = 0; x < d; x++) {
@@ -1372,6 +1388,7 @@ void flash_attention_2_forward_kernelHalf(
const float INFINITY = 65500.0f;
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
int bz = blockIdx.z; // Tr index
// Offset into Q,K,V,O - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
@@ -1381,13 +1398,17 @@ void flash_attention_2_forward_kernelHalf(
extern __shared__ float sram[];
int tile_size = Bc * d; // size of Qi, Kj, Vj
float* Qi = sram;
float* Kj = &sram[tile_size];
float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 3];
float* KVj = &sram[tile_size];
// float* Vj = &sram[tile_size * 2];
float* S = &sram[tile_size * 2];
//for (int i = q_start_offset; i < Tr; ++i) {
for (int i = q_start_offset; i < Tr; ++i) {
int i = bz;
if (i >= q_start_offset && i < Tr)
{
if (i * Br + tx >= N)
break; // break if we are done with the sequence
return; // break if we are done with the sequence
// Load Qi from HBM to SRAM, l and m to registers
for (int x = 0; x < d; x++) {
@@ -1399,10 +1420,10 @@ void flash_attention_2_forward_kernelHalf(
// Causal mask: j <= i
for (int j = 0; j <= i; ++j) {
__syncthreads();
// Load Kj, Vj from HBM to SRAM
// Load Kj from HBM to SRAM
for (int x = 0; x < d; x++) {
Kj[(tx * d) + x] = __half2float(K[qkv_offset + (tile_size * j) + (tx * d) + x]);
Vj[(tx * d) + x] = __half2float(V[qkv_offset + (tile_size * j) + (tx * d) + x]);
KVj[(tx * d) + x] = __half2float(K[qkv_offset + (tile_size * j) + (tx * d) + x]);
// Vj[(tx * d) + x] = __half2float(V[qkv_offset + (tile_size * j) + (tx * d) + x]);
}
__syncthreads();
// S_i^j = softmax_scale * QiKj^T
@@ -1415,9 +1436,9 @@ void flash_attention_2_forward_kernelHalf(
break;
float sum = 0;
for (int x = 0; x < d; x++)
sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
sum += Qi[(tx * d) + x] * KVj[(y * d) + x];
sum *= softmax_scale;
S[y] = sum;
S[(Bc * tx) + y] = sum;
if (sum > row_m)
row_m = sum;
@@ -1434,14 +1455,21 @@ void flash_attention_2_forward_kernelHalf(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
S[y] = __expf(S[y] - new_row_m);
row_l += S[y];
S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - new_row_m);
row_l += S[(Bc * tx) + y];
}
// l_i^j = (exp(m_i^j-1 - m_i^j) * l_i^j-1) + row_sum(P_i^j)
float row_m_exp = __expf(row_m_prev - new_row_m);
float new_row_l = (row_m_exp * row_l_prev) + row_l;
__syncthreads();
// Load Vj from HBM to SRAM
for (int x = 0; x < d; x++) {
KVj[(tx * d) + x] = __half2float(V[qkv_offset + (tile_size * j) + (tx * d) + x]);
}
__syncthreads();
// O_i^j = diag(exp(m_i^j-1 - m_i^j))^-1 * O_i^j-1 + P_i^jVj
for (int x = 0; x < d; x++) {
float pv = 0; // Pij * Vj
@@ -1450,7 +1478,7 @@ void flash_attention_2_forward_kernelHalf(
break; // break if we are done with the sequence
if (i * Br + tx < j * Bc + y)
break;
pv += S[y] * Vj[(y * d) + x];
pv += S[(Bc * tx) + y] * KVj[(y * d) + x];
}
O[qkv_offset + (tile_size * i) + (tx * d) + x] = \
__float2half(row_m_exp * __half2float(O[qkv_offset + (tile_size * i) + (tx * d) + x]) + pv);
@@ -1494,6 +1522,7 @@ void flash_attention_2_backward_kernelHalf(
const float INFINITY = 65500.0f;
int tx = threadIdx.x;
int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
int bz = blockIdx.z; // Tc index
// Offset into Q,K,V,O - different for each batch and head
int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
@@ -1516,7 +1545,9 @@ void flash_attention_2_backward_kernelHalf(
float* S = &sram[col_tile_size * 2 + row_tile_size * 2];
//float* dS = &sram[col_tile_size * 2 + row_tile_size * 2 + Bc * Br];
for (int j = 0; j < Tc; j++) {
//for (int j = 0; j < Tc; j++) {
int j = bz;
if (j < Tc) {
// Load Kj, Vj to SRAM
for (int x = 0; x < d; x++) {
@@ -2512,7 +2543,7 @@ private void FlashAttention(TSCudaContext context, Tensor Q, Tensor K, Tensor V,
int N = (int)Q.Sizes[2];
int d = (int)Q.Sizes[3];

int Br = Math.Min(N, 512);
int Br = 32;
while (Br > 1)
{
if (N % Br == 0)
@@ -2529,14 +2560,14 @@ private void FlashAttention(TSCudaContext context, Tensor Q, Tensor K, Tensor V,
int startTr = q_start_offset / Br;

// Calculate SRAM size needed per block
int col_tile_size = Bc * d; // size of Kj, Vj
int col_tile_size = Bc * d; // size of KVj
int row_tile_size = Br * d; // size of Qi
int sram_size =
(2 * col_tile_size * sizeof(float)) // SRAM size for Kj, Vj
(col_tile_size * sizeof(float)) // SRAM size for KVj
+ (row_tile_size * sizeof(float)) // SRAM size for Qi
+ (Bc * sizeof(float)); // SRAM size for S

dim3 grid = new dim3(B, nh);
+ (Bc * Br * sizeof(float)); // SRAM size for S
dim3 grid = new dim3(B, nh, Tr);
dim3 block = new dim3(Br);


@@ -2576,15 +2607,12 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso
CudaContext cudaContext = context.CudaContextForTensor(O);
cudaContext.SetCurrent();

//int Br = 1;
//int Bc = 1;

int B = (int)Q.Sizes[0];
int nh = (int)Q.Sizes[1];
int N = (int)Q.Sizes[2];
int d = (int)Q.Sizes[3];

int Br = Math.Min(N, 128);
int Br = 32;
while (Br > 1)
{
if (N % Br == 0)
@@ -2608,7 +2636,7 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso
+ (2 * row_tile_size * sizeof(float)) // SRAM size for Qi, dOi
+ (Br * Bc * sizeof(float)); // SRAM size for S

dim3 grid = new dim3(B, nh);
dim3 grid = new dim3(B, nh, Tc);
dim3 block = new dim3(Br);


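The backward launcher applies the same trick along the other axis: the grid's z dimension now covers the Tc key/value tiles, the former for (int j = 0; j < Tc; j++) becomes int j = blockIdx.z guarded by if (j < Tc), and the iteration over query tiles stays inside the kernel as before. A rough sketch of the launch shape, with runtime-API syntax and names standing in for the project's Invoke wrapper:

// Sketch of the backward launch shape (illustrative only, not the project's API).
// Each block owns one key/value tile j; query tiles are still walked inside the kernel.
dim3 grid(B, nh, Tc);   // batch x heads x key/value tiles
dim3 block(Br);         // one thread per row of a tile
// flash_attention_2_backward_kernel<<<grid, block, sram_size>>>(Q, K, V, O, dO, L,
//     N, d, Tc, Tr, Bc, Br, softmax_scale, dQ, dK, dV);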
@@ -2629,8 +2657,8 @@ private void FlashAttentionGrad(TSCudaContext context, Tensor Q, Tensor K, Tenso
}

Invoke(context, cudaContext, kernelName, grid, block, (uint)sram_size,
CUstream.NullStream, QPtr, KPtr, VPtr, OPtr, dOPtr, LPtr, N, d, Tc, Tr, Bc, Br, softmax_scale,
dQPtr, dKPtr, dVPtr);
CUstream.NullStream, QPtr, KPtr, VPtr, OPtr, dOPtr, LPtr, N, d, Tc, Tr, Bc, Br, softmax_scale,
dQPtr, dKPtr, dVPtr);
}
catch (Exception ex)
{
@@ -3558,4 +3586,4 @@ private void Invoke(TSCudaContext context, CudaContext cudaContext, string kerne
kernel.RunAsync(stream, args);
}
}
}
}
