From 3e06fcac0994dee80c9ac039846e671ad4699e38 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 12 Mar 2024 12:49:17 -0400 Subject: [PATCH] llama : fix Mamba inference for pipeline parallelism --- llama.cpp | 134 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 80 insertions(+), 54 deletions(-) diff --git a/llama.cpp b/llama.cpp index f307160cb0598b..2d716f11c7ad9e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2082,7 +2082,7 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, kv_size] struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch] #ifdef GGML_USE_MPI @@ -5518,6 +5518,9 @@ struct llm_build_context { lctx.inp_K_shift = nullptr; lctx.inp_mean = nullptr; lctx.inp_cls = nullptr; + lctx.inp_s_copy = nullptr; + lctx.inp_s_mask = nullptr; + lctx.inp_s_seq = nullptr; } void free() { @@ -5559,14 +5562,14 @@ struct llm_build_context { GGML_ASSERT(kv_self.recurrent); - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + struct ggml_tensor * state_copy = build_inp_s_copy(); for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy); + conv_states = ggml_get_rows(ctx0, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); // TODO: name the intermediate tensors with cb() @@ -5665,6 +5668,27 @@ struct llm_build_context { return lctx.inp_cls; } + struct ggml_tensor * build_inp_s_copy() { + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + cb(lctx.inp_s_copy, "inp_s_copy", -1); + ggml_set_input(lctx.inp_s_copy); + return lctx.inp_s_copy; + } + + struct ggml_tensor * build_inp_s_mask() { + lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + cb(lctx.inp_s_mask, "inp_s_mask", -1); + ggml_set_input(lctx.inp_s_mask); + return lctx.inp_s_mask; + } + + struct ggml_tensor * build_inp_s_seq() { + lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + cb(lctx.inp_s_seq, "inp_s_seq", -1); + ggml_set_input(lctx.inp_s_seq); + return lctx.inp_s_seq; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -8148,12 +8172,8 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - struct ggml_tensor * state_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); - lctx.inp_s_mask = state_mask; - lctx.inp_s_seq = state_seq; - ggml_set_input(state_mask); - ggml_set_input(state_seq); + struct ggml_tensor * state_mask = build_inp_s_mask(); + struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { // (ab)using the KV cache to store the states @@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); } - if (batch.pos) { + if (batch.pos && lctx.inp_pos) { const int64_t n_tokens = batch.n_tokens; ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); @@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { "non-causal attention with generative models is not supported" ); - // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn) { - const int64_t n_kv = kv_self.n; - const int64_t n_tokens = batch.n_tokens; + if (lctx.inp_KQ_mask) { + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + if (cparams.causal_attn) { + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = batch.n_tokens; - assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - float * data = (float *) lctx.inp_KQ_mask->data; + float * data = (float *) lctx.inp_KQ_mask->data; - // For causal attention, use only the previous KV cells - // of the correct sequence for each token of the batch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the batch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j][0]; - for (int i = 0; i < n_kv; ++i) { - float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { - f = -INFINITY; - } else { - f = 0.0f; + for (int i = 0; i < n_kv; ++i) { + float f; + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + f = -INFINITY; + } else { + f = 0.0f; + } + data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } - data[h*(n_kv*n_tokens) + j*n_kv + i] = f; } } - } - } else { - // when using kv cache, the mask needs to match the kv cache size - const int64_t n_tokens = batch.n_tokens; - const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; + } else { + // when using kv cache, the mask needs to match the kv cache size + const int64_t n_tokens = batch.n_tokens; + const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens; - assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); - float * data = (float *) lctx.inp_KQ_mask->data; + float * data = (float *) lctx.inp_KQ_mask->data; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_seq_id seq_id = batch.seq_id[j][0]; - for (int i = 0; i < n_tokens; ++i) { - float f = -INFINITY; - for (int s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id[i][s] == seq_id) { - f = 0.0f; - break; + for (int i = 0; i < n_tokens; ++i) { + float f = -INFINITY; + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + if (batch.seq_id[i][s] == seq_id) { + f = 0.0f; + break; + } } - } - data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; - } + data[h*(n_tokens*n_tokens) + j*n_stride + i] = f; + } - for (int i = n_tokens; i < n_stride; ++i) { - data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY; + } } } } @@ -9272,11 +9294,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) { - llama_set_s_copy(lctx); - { + ggml_backend_sched_reset(lctx.sched); + ggml_cgraph * gf = llama_build_graph_s_copy(lctx); + ggml_backend_sched_alloc_graph(lctx.sched, gf); + + llama_set_s_copy(lctx); + llama_graph_compute(lctx, gf, lctx.cparams.n_threads); need_reserve = true;