Skip to content

Commit

Permalink
mtl : add rope kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 31, 2023
1 parent 6af6a05 commit 1213af7
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 9 deletions.
75 changes: 74 additions & 1 deletion examples/mtl/mtl.m
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@

id<MTLFunction> function_mul_mat_q4_0;
id<MTLComputePipelineState> pipeline_mul_mat_q4_0;

id<MTLFunction> function_rope;
id<MTLComputePipelineState> pipeline_rope;
};

// MSL code
Expand Down Expand Up @@ -148,6 +151,10 @@
ctx->function_mul_mat_q4_0 = [ctx->library newFunctionWithName:@"kernel_mul_mat_q4_0"];
ctx->pipeline_mul_mat_q4_0 = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_mat_q4_0 error:nil];
fprintf(stderr, "%s: loaded kernel_mul_mat_q4_0: %p\n", __func__, (void *) ctx->pipeline_mul_mat_q4_0);

ctx->function_rope = [ctx->library newFunctionWithName:@"kernel_rope"];
ctx->pipeline_rope = [ctx->device newComputePipelineStateWithFunction:ctx->function_rope error:nil];
fprintf(stderr, "%s: loaded kernel_rope: %p\n", __func__, (void *) ctx->pipeline_rope);
}

// MTLBuffer approach
Expand Down Expand Up @@ -250,6 +257,10 @@ int llama_mtl_eval(
fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

switch (gf->nodes[i]->op) {
case GGML_OP_RESHAPE:
{
// noop
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
Expand Down Expand Up @@ -453,6 +464,68 @@ int llama_mtl_eval(

[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}

id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst);

const int64_t ne00 = gf->nodes[i]->src0->ne[0];
const int64_t ne01 = gf->nodes[i]->src0->ne[1];
const int64_t ne02 = gf->nodes[i]->src0->ne[2];
const int64_t ne03 = gf->nodes[i]->src0->ne[3];

const uint64_t nb00 = gf->nodes[i]->src0->nb[0];
const uint64_t nb01 = gf->nodes[i]->src0->nb[1];
const uint64_t nb02 = gf->nodes[i]->src0->nb[2];
const uint64_t nb03 = gf->nodes[i]->src0->nb[3];

const int64_t ne0 = gf->nodes[i]->ne[0];
const int64_t ne1 = gf->nodes[i]->ne[1];
const int64_t ne2 = gf->nodes[i]->ne[2];
const int64_t ne3 = gf->nodes[i]->ne[3];

const uint64_t nb0 = gf->nodes[i]->nb[0];
const uint64_t nb1 = gf->nodes[i]->nb[1];
const uint64_t nb2 = gf->nodes[i]->nb[2];
const uint64_t nb3 = gf->nodes[i]->nb[3];

const int n_past = ((int32_t *) gf->nodes[i]->src1->data)[0]; // TODO: TMP !!!!!
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];

printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);

[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
GGML_ASSERT(false);
Expand Down Expand Up @@ -486,7 +559,7 @@ int llama_mtl_eval(

{
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
fprintf(stderr, "%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
}

// TODO
Expand Down
55 changes: 55 additions & 0 deletions examples/mtl/mtl.metal
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,58 @@ kernel void kernel_mul_mat_q4_0(
dst[r1*ne0 + r0] = sum[0];
}
}

// Rotary positional embedding (RoPE) kernel.
//
// One thread is launched per row of src0 (dispatch grid is ne01 x ne02 x ne03
// with a 1x1x1 threadgroup on the host side), and each thread rotates its row
// in-place-style into dst, processing elements in consecutive pairs.
//
// Parameters mirror the ggml tensor layout convention:
//   ne0x / nb0x — element counts and byte strides of src0 per dimension
//   ne0..ne3 / nb0..nb3 — element counts and byte strides of dst
//   n_past  — number of previously processed tokens; offsets the position p
//   n_dims  — number of dimensions the rotation frequency is scaled over
//   mode    — bit 0: when clear, position p includes n_past; bit 1: neox-style
//
// NOTE(review): only the non-neox path (mode bit 1 clear) is implemented; the
// neox branch is an explicit TODO and currently writes nothing.
kernel void kernel_rope(
device const void * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int & n_past,
constant int & n_dims,
constant int & mode,
uint3 tpig[[thread_position_in_grid]]) {
// Grid coordinates identify which row this thread handles:
// i1 = row within a matrix, i2 = matrix (token), i3 = batch.
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];

const bool is_neox = mode & 2;
// Per-pair frequency decay factor: theta_i = p * 10000^(-2i/n_dims).
const float theta_scale = pow(10000.0, -2.0f/n_dims);

// Absolute token position; mode bit 0 selects whether past tokens count.
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

float theta = (float)p;

if (!is_neox) {
// Rotate adjacent element pairs (x0, x1) by angle theta, multiplying
// theta by theta_scale after each pair so frequency falls off with i0.
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);

theta *= theta_scale;

// Byte-stride addressing: strides (nb0x) are in bytes, hence the
// intermediate "device char *" casts before reinterpreting as float.
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

const float x0 = src[0];
const float x1 = src[1];

// Standard 2D rotation of the pair by theta.
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
// TODO: implement
// NOTE(review): neox-mode inputs reaching this kernel are silently
// left unwritten in dst — host code must not dispatch with mode & 2.
}
}
24 changes: 16 additions & 8 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1270,19 +1270,20 @@ static bool llama_eval_internal(

// self-attention
{
auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(x, "mtl-check");
}
//auto * x = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);

// compute Q and K and RoPE them
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");

// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Qcur, "mtl-check");
}

// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
Expand Down Expand Up @@ -1437,7 +1438,14 @@ static bool llama_eval_internal(
//ggml_graph_compute (ctx0, &gf);

// lets export a smaller graph to get things rolling -- baby steps first
ggml_build_forward_expand(&gf_export, ggml_get_tensor(ctx0, "mtl-check"));
{
struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
if (!t) {
fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
exit(1);
}
ggml_build_forward_expand(&gf_export, t);
}

// print
{
Expand Down

0 comments on commit 1213af7

Please sign in to comment.