Skip to content

Commit

Permalink
mtl : add reshape and transport handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 31, 2023
1 parent 1213af7 commit 9883af8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
1 change: 1 addition & 0 deletions examples/mtl/mtl.m
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ int llama_mtl_eval(

switch (gf->nodes[i]->op) {
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
{
// noop
} break;
Expand Down
8 changes: 6 additions & 2 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -15011,15 +15011,19 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **

// create the tensor
// "view" operations are handled differently
// TODO: handle inplac ops - currentl a copy is always made

struct ggml_tensor * tensor = NULL;

switch (eop) {
// TODO: implement other view ops
case GGML_OP_RESHAPE:
{
// TODO: implement other dims
tensor = ggml_reshape_3d(*ctx_eval, args[0], ne[0], ne[1], ne[2]);
tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
} break;
case GGML_OP_TRANSPOSE:
{
tensor = ggml_transpose(*ctx_eval, args[0]);
} break;
default:
{
Expand Down
9 changes: 4 additions & 5 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,15 +1279,14 @@ static bool llama_eval_internal(
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");

// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Qcur, "mtl-check");
}

// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Vcur, "mtl-check");
}

struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
Expand Down

0 comments on commit 9883af8

Please sign in to comment.