diff --git a/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs b/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
index 78552fb..f086bd5 100644
--- a/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
+++ b/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs
@@ -41,6 +41,8 @@ void flash_attention_2_forward_kernel(
 )
 {
     const float INFINITY = 9999999999.9f;
     int tx = threadIdx.x;
+    int txd = tx * d;
+
     int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
     int bz = blockIdx.z; // Tr index
@@ -64,8 +66,9 @@ void flash_attention_2_forward_kernel(
         return; // break if we are done with the sequence
 
     // Load Qi from HBM to SRAM, l and m to registers
+
     for (int x = 0; x < d; x++) {
-        Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
+        Qi[txd + x] = Q[qkv_offset + (tile_size * i) + txd + x];
     }
     float row_m_prev = -INFINITY;
     float row_l_prev = 0;
@@ -75,8 +78,8 @@ void flash_attention_2_forward_kernel(
         __syncthreads();
         // Load Kj, Vj from HBM to SRAM
         for (int x = 0; x < d; x++) {
-            KVj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
-            // Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
+            KVj[txd + x] = K[qkv_offset + (tile_size * j) + txd + x];
+            // Vj[txd + x] = V[qkv_offset + (tile_size * j) + txd + x];
         }
         __syncthreads();
         // S_i^j = softmax_scale * QiKj^T
@@ -88,8 +91,25 @@ void flash_attention_2_forward_kernel(
             if (i * Br + tx < j * Bc + y)
                 break;
             float sum = 0;
-            for (int x = 0; x < d; x++)
-                sum += Qi[(tx * d) + x] * KVj[(y * d) + x];
+
+            if (d == 128)
+            {
+#pragma unroll 128
+                for (int x = 0; x < 128; x++)
+                    sum += Qi[txd + x] * KVj[(y * 128) + x];
+
+            }
+            else if (d == 64)
+            {
+#pragma unroll 64
+                for (int x = 0; x < 64; x++)
+                    sum += Qi[txd + x] * KVj[(y * 64) + x];
+            }
+            else
+            {
+                for (int x = 0; x < d; x++)
+                    sum += Qi[txd + x] * KVj[(y * d) + x];
+            }
             sum *= softmax_scale;
 
             S[(Bc * tx) + y] = sum;
@@ -118,8 +138,8 @@ void flash_attention_2_forward_kernel(
         __syncthreads();
 
         for (int x = 0; x < d; x++) {
-//            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
-            KVj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
+//            Kj[txd + x] = K[qkv_offset + (tile_size * j) + txd + x];
+            KVj[txd + x] = V[qkv_offset + (tile_size * j) + txd + x];
         }
         __syncthreads();
 
@@ -133,8 +153,8 @@ void flash_attention_2_forward_kernel(
                     break;
                 pv += S[(Bc * tx) + y] * KVj[(y * d) + x];
             }
-            O[qkv_offset + (tile_size * i) + (tx * d) + x] = \
-                row_m_exp * O[qkv_offset + (tile_size * i) + (tx * d) + x] + pv;
+            O[qkv_offset + (tile_size * i) + txd + x] = \
+                row_m_exp * O[qkv_offset + (tile_size * i) + txd + x] + pv;
         }
 
         // Update m and l
@@ -144,7 +164,7 @@ void flash_attention_2_forward_kernel(
 
     // O_i = diag(l_i^{Tc})^-1 * O_i^{Tc}
     for (int x = 0; x < d; x++)
-        O[qkv_offset + (tile_size * i) + (tx * d) + x] /= row_l_prev;
+        O[qkv_offset + (tile_size * i) + txd + x] /= row_l_prev;
     // L_i = m_i^{Tc} + log(l_i^{Tc})
     L[lm_offset + (Br * i) + tx] = row_m_prev + __logf(row_l_prev);
 }
@@ -171,6 +191,7 @@ void flash_attention_2_backward_kernel(
 )
 {
     const float INFINITY = 9999999999.9f;
     int tx = threadIdx.x;
+    int txd = tx * d;
     int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
     int bz = blockIdx.z; // Tc index;
@@ -203,8 +224,8 @@ void flash_attention_2_backward_kernel(
 
         // Load Kj, Vj to SRAM
         for (int x = 0; x < d; x++) {
-            Kj[(tx * d) + x] = K[qkv_offset + (col_tile_size * j) + (tx * d) + x];
-            Vj[(tx * d) + x] = V[qkv_offset + (col_tile_size * j) + (tx * d) + x];
+            Kj[txd + x] = K[qkv_offset + (col_tile_size * j) + txd + x];
+            Vj[txd + x] = V[qkv_offset + (col_tile_size * j) + txd + x];
         }
 
         for (int i = j; i < Tr; i++) {
@@ -213,43 +234,68 @@ void flash_attention_2_backward_kernel(
            // Also load l, m to registers
             float Di = 0;
             for (int x = 0; x < d; x++) {
-                Qi[(tx * d) + x] = Q[qkv_offset + (row_tile_size * i) + (tx * d) + x];
-                //Oi[(tx * d) + x] = O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
-                dOi[(tx * d) + x] = dO[qkv_offset + (row_tile_size * i) + (tx * d) + x];
-                Di += dOi[(tx * d) + x] * O[qkv_offset + (row_tile_size * i) + (tx * d) + x];
+                Qi[txd + x] = Q[qkv_offset + (row_tile_size * i) + txd + x];
+                //Oi[txd + x] = O[qkv_offset + (row_tile_size * i) + txd + x];
+                dOi[txd + x] = dO[qkv_offset + (row_tile_size * i) + txd + x];
+                Di += dOi[txd + x] * O[qkv_offset + (row_tile_size * i) + txd + x];
             }
             float l_curr = L[lm_offset + (Br * i) + tx];
 
             // Sij = softmax_scale * QiKj^T
             // Sij[tx][y] = softmax_scale * Sum_{y = 0}^{Bc-1} Qi[tx][x] * Kj[y][x]
-            for (int y = 0; y < Bc; y++) {
-                float sum = 0;
-                for (int x = 0; x < d; x++) {
-                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
-                }
-                sum *= softmax_scale;
-                if (i * Br + tx < j * Bc + y)
-                    sum = -INFINITY;
-                S[(Bc * tx) + y] = sum;
-            }
 
             // Pij = diag(li)^-1 * exp(Sij - mi)
             // Pij[tx][y] = (1 / li[tx]) * exp(Sij[tx][y] - mi[tx])
+            for (int y = 0; y < Bc; y++) {
+                float sum = 0;
+
+                if (d == 128)
+                {
+#pragma unroll 128
+                    for (int x = 0; x < 128; x++) {
+                        sum += Qi[txd + x] * Kj[(y * 128) + x];
+                    }
+                }
+                else if (d == 64)
+                {
+#pragma unroll 64
+                    for (int x = 0; x < 64; x++) {
+                        sum += Qi[txd + x] * Kj[(y * 64) + x];
+                    }
+                }
+                else
+                {
+                    for (int x = 0; x < d; x++) {
+                        sum += Qi[txd + x] * Kj[(y * d) + x];
+                    }
+                }
+
+                sum *= softmax_scale;
                 if (i * Br + tx < j * Bc + y)
                     S[(Bc * tx) + y] = 0;
                 else
-                    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - l_curr);
+                    S[(Bc * tx) + y] = __expf(sum - l_curr);
             }
+
+            //// Pij = diag(li)^-1 * exp(Sij - mi)
+            //// Pij[tx][y] = (1 / li[tx]) * exp(Sij[tx][y] - mi[tx])
+            //for (int y = 0; y < Bc; y++) {
+            //    if (i * Br + tx < j * Bc + y)
+            //        S[(Bc * tx) + y] = 0;
+            //    else
+            //        S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - l_curr);
+            //}
 
             __syncthreads();
             // dVj <- dVj + Pij^T * dOi
             // dVj[tx][x] = dVj[tx][x] + Sum_{y = 0}^{Br-1} Pij[y][tx] * dOi[tx][x]
             for (int x = 0; x < d; x++) {
                 float sum = 0;
+                float dOi_x = dOi[txd + x];
                 for (int y = 0; y < Br; y++) {
-                    sum += S[(Bc * y) + tx] * dOi[(tx * d) + x];
+                    sum += S[(Bc * y) + tx] * dOi_x;
                 }
-                atomicAdd(&dV[qkv_offset + (row_tile_size * j) + (tx * d) + x], sum);
+                atomicAdd(&dV[qkv_offset + (row_tile_size * j) + txd + x], sum);
             }
 
             // dPij <- dOi * Vj^T
@@ -259,9 +305,28 @@ void flash_attention_2_backward_kernel(
             // dSij[tx][y] = Pij[tx][y] * (dPij[tx][y] - Di[tx])
             for (int y = 0; y < Bc; y++) {
                 float sum = 0;
-                for (int x = 0; x < d; x++) {
-                    sum += dOi[(tx * d) + x] * Vj[(y * d) + x];
+
+                if (d == 128)
+                {
+#pragma unroll 128
+                    for (int x = 0; x < 128; x++) {
+                        sum += dOi[txd + x] * Vj[(y * 128) + x];
+                    }
+                }
+                else if (d == 64)
+                {
+#pragma unroll 64
+                    for (int x = 0; x < 64; x++) {
+                        sum += dOi[txd + x] * Vj[(y * 64) + x];
+                    }
                 }
+                else
+                {
+                    for (int x = 0; x < d; x++) {
+                        sum += dOi[txd + x] * Vj[(y * d) + x];
+                    }
+                }
+
                 S[(Bc * tx) + y] = S[(Bc * tx) + y] * (sum - Di);
             }
@@ -273,7 +338,7 @@ void flash_attention_2_backward_kernel(
                     sum += S[(Bc * tx) + y] * Kj[(y * d) + x];
                 }
                 sum *= softmax_scale;
-                atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + (tx * d) + x], sum);
+                atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + txd + x], sum);
             }
             __syncthreads();
             // dKj <- dKj + softmax_scale * dSij^TQi
@@ -284,7 +349,7 @@ void flash_attention_2_backward_kernel(
                     sum += S[(Bc * y) + tx] * Qi[(y * d) + x];
                 }
                 sum *= softmax_scale;
-                atomicAdd(&dK[qkv_offset + (row_tile_size * j) + (tx * d) + x], sum);
+                atomicAdd(&dK[qkv_offset + (row_tile_size * j) + txd + x], sum);
             }
         }
     }
@@ -1387,6 +1452,7 @@ void flash_attention_2_forward_kernelHalf(
 )
 {
     const float INFINITY = 65500.0f;
     int tx = threadIdx.x;
+    int txd = tx * d;
     int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
     int bz = blockIdx.z; // Tr index
@@ -1412,7 +1478,7 @@ void flash_attention_2_forward_kernelHalf(
 
     // Load Qi from HBM to SRAM, l and m to registers
     for (int x = 0; x < d; x++) {
-        Qi[(tx * d) + x] = __half2float(Q[qkv_offset + (tile_size * i) + (tx * d) + x]);
+        Qi[txd + x] = __half2float(Q[qkv_offset + (tile_size * i) + txd + x]);
     }
     float row_m_prev = -INFINITY;
     float row_l_prev = 0;
@@ -1422,8 +1488,8 @@ void flash_attention_2_forward_kernelHalf(
         __syncthreads();
         // Load Kj from HBM to SRAM
        for (int x = 0; x < d; x++) {
-            KVj[(tx * d) + x] = __half2float(K[qkv_offset + (tile_size * j) + (tx * d) + x]);
-            // Vj[(tx * d) + x] = __half2float(V[qkv_offset + (tile_size * j) + (tx * d) + x]);
+            KVj[txd + x] = __half2float(K[qkv_offset + (tile_size * j) + txd + x]);
+            // Vj[txd + x] = __half2float(V[qkv_offset + (tile_size * j) + txd + x]);
         }
         __syncthreads();
         // S_i^j = softmax_scale * QiKj^T
@@ -1436,7 +1502,7 @@ void flash_attention_2_forward_kernelHalf(
                 break;
             float sum = 0;
             for (int x = 0; x < d; x++)
-                sum += Qi[(tx * d) + x] * KVj[(y * d) + x];
+                sum += Qi[txd + x] * KVj[(y * d) + x];
             sum *= softmax_scale;
 
             S[(Bc * tx) + y] = sum;
@@ -1466,7 +1532,7 @@ void flash_attention_2_forward_kernelHalf(
         __syncthreads();
         // Load Vj from HBM to SRAM
         for (int x = 0; x < d; x++) {
-            KVj[(tx * d) + x] = __half2float(V[qkv_offset + (tile_size * j) + (tx * d) + x]);
+            KVj[txd + x] = __half2float(V[qkv_offset + (tile_size * j) + txd + x]);
         }
         __syncthreads();
 
@@ -1480,8 +1546,8 @@ void flash_attention_2_forward_kernelHalf(
                     break;
                 pv += S[(Bc * tx) + y] * KVj[(y * d) + x];
             }
-            O[qkv_offset + (tile_size * i) + (tx * d) + x] = \
-                __float2half(row_m_exp * __half2float(O[qkv_offset + (tile_size * i) + (tx * d) + x]) + pv);
+            O[qkv_offset + (tile_size * i) + txd + x] = \
+                __float2half(row_m_exp * __half2float(O[qkv_offset + (tile_size * i) + txd + x]) + pv);
         }
 
         // Update m and l
@@ -1491,8 +1557,8 @@ void flash_attention_2_forward_kernelHalf(
 
     // O_i = diag(l_i^{Tc})^-1 * O_i^{Tc}
     for (int x = 0; x < d; x++)
-        //O[qkv_offset + (tile_size * i) + (tx * d) + x] /= row_l_prev;
-        O[qkv_offset + (tile_size * i) + (tx * d) + x] = __float2half(__half2float(O[qkv_offset + (tile_size * i) + (tx * d) + x]) / row_l_prev);
+        //O[qkv_offset + (tile_size * i) + txd + x] /= row_l_prev;
+        O[qkv_offset + (tile_size * i) + txd + x] = __float2half(__half2float(O[qkv_offset + (tile_size * i) + txd + x]) / row_l_prev);
 
 
     // L_i = m_i^{Tc} + log(l_i^{Tc})
@@ -1521,6 +1587,7 @@ void flash_attention_2_backward_kernelHalf(
 )
 {
     const float INFINITY = 65500.0f;
     int tx = threadIdx.x;
+    int txd = tx * d;
     int bx = blockIdx.x; int by = blockIdx.y; // batch and head index
     int bz = blockIdx.z; // Tc index
@@ -1551,8 +1618,8 @@ void flash_attention_2_backward_kernelHalf(
 
         // Load Kj, Vj to SRAM
         for (int x = 0; x < d; x++) {
-            Kj[(tx * d) + x] = __half2float(K[qkv_offset + (col_tile_size * j) + (tx * d) + x]);
-            Vj[(tx * d) + x] = __half2float(V[qkv_offset + (col_tile_size * j) + (tx * d) + x]);
+            Kj[txd + x] = __half2float(K[qkv_offset + (col_tile_size * j) + txd + x]);
+            Vj[txd + x] = __half2float(V[qkv_offset + (col_tile_size * j) + txd + x]);
         }
 
         for (int i = j; i < Tr; i++) {
@@ -1561,42 +1628,48 @@ void flash_attention_2_backward_kernelHalf(
             // Also load l, m to registers
             float Di = 0;
             for (int x = 0; x < d; x++) {
-                Qi[(tx * d) + x] = __half2float(Q[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
-                dOi[(tx * d) + x] = __half2float(dO[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
-                Di += dOi[(tx * d) + x] * __half2float(O[qkv_offset + (row_tile_size * i) + (tx * d) + x]);
+                Qi[txd + x] = __half2float(Q[qkv_offset + (row_tile_size * i) + txd + x]);
+                dOi[txd + x] = __half2float(dO[qkv_offset + (row_tile_size * i) + txd + x]);
+                Di += dOi[txd + x] * __half2float(O[qkv_offset + (row_tile_size * i) + txd + x]);
             }
             float l_curr = L[lm_offset + (Br * i) + tx];
 
             // Sij = softmax_scale * QiKj^T
             // Sij[tx][y] = softmax_scale * Sum_{y = 0}^{Bc-1} Qi[tx][x] * Kj[y][x]
+
+            // Pij = diag(li)^-1 * exp(Sij - mi)
+            // Pij[tx][y] = (1 / li[tx]) * exp(Sij[tx][y] - mi[tx])
+
             for (int y = 0; y < Bc; y++) {
                 float sum = 0;
                 for (int x = 0; x < d; x++) {
-                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
+                    sum += Qi[txd + x] * Kj[(y * d) + x];
                 }
                 sum *= softmax_scale;
-                if (i * Br + tx < j * Bc + y)
-                    sum = -INFINITY;
-                S[(Bc * tx) + y] = sum;
-            }
-
-            // Pij = diag(li)^-1 * exp(Sij - mi)
-            // Pij[tx][y] = (1 / li[tx]) * exp(Sij[tx][y] - mi[tx])
-            for (int y = 0; y < Bc; y++) {
                 if (i * Br + tx < j * Bc + y)
                     S[(Bc * tx) + y] = 0;
                 else
-                    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - l_curr);
+                    S[(Bc * tx) + y] = __expf(sum - l_curr);
             }
+
+            //// Pij = diag(li)^-1 * exp(Sij - mi)
+            //// Pij[tx][y] = (1 / li[tx]) * exp(Sij[tx][y] - mi[tx])
+            //for (int y = 0; y < Bc; y++) {
+            //    if (i * Br + tx < j * Bc + y)
+            //        S[(Bc * tx) + y] = 0;
+            //    else
+            //        S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - l_curr);
+            //}
             __syncthreads();
             // dVj <- dVj + Pij^T * dOi
             // dVj[tx][x] = dVj[tx][x] + Sum_{y = 0}^{Br-1} Pij[y][tx] * dOi[tx][x]
             for (int x = 0; x < d; x++) {
                 float sum = 0;
+                float dOi_x = dOi[txd + x];
                 for (int y = 0; y < Br; y++) {
-                    sum += S[(Bc * y) + tx] * dOi[(tx * d) + x];
+                    sum += S[(Bc * y) + tx] * dOi_x;
                 }
-                atomicAdd(&dV[qkv_offset + (row_tile_size * j) + (tx * d) + x], __float2half(sum));
+                atomicAdd(&dV[qkv_offset + (row_tile_size * j) + txd + x], __float2half(sum));
             }
 
             // dPij <- dOi * Vj^T
@@ -1607,7 +1680,7 @@ void flash_attention_2_backward_kernelHalf(
             for (int y = 0; y < Bc; y++) {
                 float sum = 0;
                 for (int x = 0; x < d; x++) {
-                    sum += dOi[(tx * d) + x] * Vj[(y * d) + x];
+                    sum += dOi[txd + x] * Vj[(y * d) + x];
                 }
                 S[(Bc * tx) + y] = S[(Bc * tx) + y] * (sum - Di);
             }
@@ -1620,7 +1693,7 @@ void flash_attention_2_backward_kernelHalf(
                     sum += S[(Bc * tx) + y] * Kj[(y * d) + x];
                 }
                 sum *= softmax_scale;
-                atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + (tx * d) + x], __float2half(sum));
+                atomicAdd(&dQ[qkv_offset + (row_tile_size * i) + txd + x], __float2half(sum));
             }
             __syncthreads();
             // dKj <- dKj + softmax_scale * dSij^TQi
@@ -1631,7 +1704,7 @@ void flash_attention_2_backward_kernelHalf(
                     sum += S[(Bc * y) + tx] * Qi[(y * d) + x];
                 }
                 sum *= softmax_scale;
-                atomicAdd(&dK[qkv_offset + (row_tile_size * j) + (tx * d) + x], __float2half(sum));
+                atomicAdd(&dK[qkv_offset + (row_tile_size * j) + txd + x], __float2half(sum));
             }
         }
     }
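
Note (reviewer sketch, not part of the patch): the hot inner products above are specialized on the head dimension d so the loop trip count becomes a compile-time constant, which lets #pragma unroll fully unroll the loop; a loop bounded by the runtime value d cannot be fully unrolled. The standalone CUDA kernel below is a minimal sketch of that pattern combined with the hoisted txd = tx * d offset; the kernel name dot_rows and its buffers are illustrative only and do not appear in AdvFuncKernels.cs.

// Illustrative sketch: one thread computes row tx of S = Qi * KVj^T
// for a Bc-column tile, specializing on the common head sizes 128 and 64.
__global__ void dot_rows(const float* Qi, const float* KVj, float* S,
                         int Bc, int d)
{
    int tx = threadIdx.x;
    int txd = tx * d;                  // hoisted once per thread, as in the patch

    for (int y = 0; y < Bc; y++) {
        float sum = 0.0f;
        if (d == 128) {                // fixed trip count -> full unroll
#pragma unroll 128
            for (int x = 0; x < 128; x++)
                sum += Qi[txd + x] * KVj[(y * 128) + x];
        } else if (d == 64) {
#pragma unroll 64
            for (int x = 0; x < 64; x++)
                sum += Qi[txd + x] * KVj[(y * 64) + x];
        } else {                       // generic fallback keeps other head sizes correct
            for (int x = 0; x < d; x++)
                sum += Qi[txd + x] * KVj[(y * d) + x];
        }
        S[(Bc * tx) + y] = sum;
    }
}

Since d is the same for every thread in the block, the branch is uniform and introduces no warp divergence; each launch takes exactly one of the three paths.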