Skip to content

Commit

Permalink
metal : improve clarity (minor) (ggerganov#10171)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 8, 2024
1 parent 841f27a commit 695ad75
Showing 1 changed file with 45 additions and 31 deletions.
76 changes: 45 additions & 31 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NW4 = NW/4;
const short NL = NW/4;
const short SH = 2*C; // shared memory per simdgroup

const short T = D + nsg*SH; // shared memory size per query in (half)
Expand All @@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results

// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o4x4_t lo[D16/NW4];
o4x4_t lo[D16/NL];

// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
Expand All @@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
}

// zero out lo
for (short i = 0; i < D16/NW4; i += NW4) {
for (short i = 0; i < D16/NL; ++i) {
lo[i] = (o4x4_t) 0.0f;
}

Expand All @@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
half M = -__FLT16_MAX__/2;

// thread indices inside the simdgroup
const short tx = tiisg%8;
const short ty = tiisg/8;
const short tx = tiisg%NL;
const short ty = tiisg/NL;

// broadcast kv
//const short rk2 = ne02/ne12;
Expand All @@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
const short ikv3 = iq3/(ne03/ne_12_3);

// load the queries from shared memory into local memory
q4x4_t mq[D16/NW4];
q4x4_t mq[D16/NL];

for (short ii = 0; ii < D16; ii += NW4) {
mq[ii/NW4] = sq4x4[ii + tx];
for (short ii = 0; ii < D16; ii += NL) {
mq[ii/NL] = sq4x4[ii + tx];
}

const bool has_mask = mask != q;
Expand Down Expand Up @@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
for (short ii = 0; ii < D16; ii += NL) {
const short i = ii + tx;

k4x4_t mk;
deq_k(pk + i/nl_k, i%nl_k, mk);

mqk +=
dot(mq[ii/NW4][0], mk[0]) +
dot(mq[ii/NW4][1], mk[1]) +
dot(mq[ii/NW4][2], mk[2]) +
dot(mq[ii/NW4][3], mk[3]);
dot(mq[ii/NL][0], mk[0]) +
dot(mq[ii/NL][1], mk[1]) +
dot(mq[ii/NL][2], mk[2]) +
dot(mq[ii/NL][3], mk[3]);
}

// simdgroup reduce
Expand Down Expand Up @@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(

// O = diag(ms)*O
#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
lo[ii/NW4] *= ms;
for (short ii = 0; ii < D16; ii += NL) {
lo[ii/NL] *= ms;
}
}

Expand All @@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
const s4x4_t ms(ss[4*cc + ty]);

#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
for (short ii = 0; ii < D16; ii += NL) {
const short i = ii + tx;

v4x4_t mv;
deq_v(pv4 + i/nl_v, i%nl_v, mv);

lo[ii/NW4] += mv*ms;
lo[ii/NL] += mv*ms;
}
}
}
Expand All @@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
// [ 5, 13, 21, 29] -> [ 5]
// [ 6, 14, 22, 30] -> [ 6]
// [ 7, 15, 23, 31] -> [ 7]
for (short ii = 0; ii < D16; ii += NW4) {
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);

lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);

lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);

lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
for (short ii = 0; ii < D16; ii += NL) {
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);

lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);

lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);

lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// store results to shared memory
for (short i = tiisg; i < D16; i += NW4) {
sr4x4[i] = lo[i/NW4];
for (short i = tiisg; i < D16; i += NL) {
sr4x4[i] = lo[i/NL];
}

threadgroup_barrier(mem_flags::mem_threadgroup);
Expand Down

0 comments on commit 695ad75

Please sign in to comment.