From feb0966af193adfe63d33bd0227a0bde093ba7b3 Mon Sep 17 00:00:00 2001 From: okada Date: Sat, 16 Dec 2023 15:55:58 +0900 Subject: [PATCH 01/18] add plamo mock --- llama.cpp | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index d6d575f9e3960..3cfe8d8d9b6f7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -195,6 +195,7 @@ enum llm_arch { LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, + LLM_ARCH_PLAMO, LLM_ARCH_UNKNOWN, }; @@ -212,6 +213,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_PLAMO, "plamo" }, }; enum llm_kv { @@ -550,7 +552,24 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, - + { + LLM_ARCH_PLAMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2635,6 +2654,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_PLAMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_13B; break; //TODO Check + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -3630,7 +3658,10 @@ static void llm_load_tensors( } } } break; - + case LLM_ARCH_PLAMO: + { + //TODO + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5424,6 +5455,122 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_plamo() { + //TODO + /* + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = 
ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + */ + } }; // @@ -5922,6 +6069,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen(); } break; + case LLM_ARCH_PLAMO: + { + result = llm.build_plamo(); + } break; default: GGML_ASSERT(false); } From 4c585b4c6c244952583afa339e64d69711df123d Mon Sep 17 00:00:00 2001 From: okada Date: Sat, 16 Dec 2023 16:24:54 +0900 Subject: [PATCH 02/18] add tensor loading --- llama.cpp | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 3cfe8d8d9b6f7..baf15863be0da 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3660,7 +3660,62 @@ static void llm_load_tensors( } break; case LLM_ARCH_PLAMO: { - //TODO + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + backend_norm = llama_backend_offload; + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? 
GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + } + } } break; default: throw std::runtime_error("unknown architecture"); From b2330f57e29426b044f454b937364a778ceb7ef7 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 15:23:59 +0900 Subject: [PATCH 03/18] plamo convert --- convert-hf-to-gguf.py | 67 +++++++++++++++++++++++++++++++++- gguf-py/gguf/constants.py | 17 +++++++++ gguf-py/gguf/tensor_mapping.py | 37 ++++++++++++------- 3 files changed, 106 insertions(+), 15 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e46a7813a78e9..6c783d79674bc 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -182,6 +182,8 @@ def from_model_architecture(model_architecture): return QwenModel if model_architecture == "MixtralForCausalLM": return MixtralModel + if model_architecture == "PlamoForCausalLM": + return PlamoModel return Model def _is_model_safetensors(self) -> bool: @@ -221,6 +223,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH: return gguf.MODEL_ARCH.QWEN if arch == "MixtralForCausalLM": return gguf.MODEL_ARCH.LLAMA + if arch == "PlamoForCausalLM": + return gguf.MODEL_ARCH.PLAMO raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -980,11 +984,72 @@ def write_tensors(self): print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") self.gguf_writer.add_tensor(new_name, data) + +class PlamoModel(Model): + def set_vocab(self): + self._set_vocab_sentencepiece() + + def set_gguf_parameters(self): + hparams = self.hparams + block_count = hparams["num_hidden_layers"] + + self.gguf_writer.add_name("PLaMo") + self.gguf_writer.add_context_length(4096) # not in config.json + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + + def write_tensors(self): + block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + + for name, 
data_torch in self.get_tensors(): + if "self_attn.rotary_emb.inv_freq" in name: + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file") + parser = argparse.ArgumentParser( + description="Convert a huggingface model to a GGML compatible file") parser.add_argument( "--vocab-only", action="store_true", help="extract only the vocab", diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 12133882be2c4..5de0e0d8b1030 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -95,6 +95,7 @@ class MODEL_ARCH(IntEnum): BLOOM = auto() STABLELM = auto() QWEN = auto() + PLAMO = auto() class MODEL_TENSOR(IntEnum): @@ -140,6 +141,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.PLAMO: "plamo", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -347,6 +349,21 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PLAMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.GPT2: [ # TODO ], diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 0115ea1c605b1..36d3860654783 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -75,6 +75,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi + "model.layers.layers.{bid}.norm", # plamo ), # Attention norm 2 @@ -94,26 +95,29 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf - "layers.{bid}.attention.wq", # llama-pth - "encoder.layer.{bid}.attention.self.query", # bert - "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.{bid}.self_attn.q_proj", # llama-hf + "layers.{bid}.attention.wq", # llama-pth + "encoder.layer.{bid}.attention.self.query", # bert + "transformer.h.{bid}.attn.q_proj", # gpt-j + 
"model.layers.layers.{bid}.self_attn.q_proj", # plamo ), # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf - "layers.{bid}.attention.wk", # llama-pth - "encoder.layer.{bid}.attention.self.key", # bert - "transformer.h.{bid}.attn.k_proj", # gpt-j + "model.layers.{bid}.self_attn.k_proj", # llama-hf + "layers.{bid}.attention.wk", # llama-pth + "encoder.layer.{bid}.attention.self.key", # bert + "transformer.h.{bid}.attn.k_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.k_proj", # plamo ), # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf - "layers.{bid}.attention.wv", # llama-pth - "encoder.layer.{bid}.attention.self.value", # bert - "transformer.h.{bid}.attn.v_proj", # gpt-j + "model.layers.{bid}.self_attn.v_proj", # llama-hf + "layers.{bid}.attention.wv", # llama-pth + "encoder.layer.{bid}.attention.self.value", # bert + "transformer.h.{bid}.attn.v_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.v_proj", # plamo ), # Attention output @@ -128,12 +132,14 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "model.layers.layers.{bid}.self_attn.o_proj", # plamo ), # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( - "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf - "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf + "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo ), # Feed-forward norm @@ -167,6 +173,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.fc_in", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "transformer.h.{bid}.mlp.w1", # qwen + "model.layers.layers.{bid}.mlp.up_proj", # plamo ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -179,6 +186,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate_proj", # llama-hf refact "layers.{bid}.feed_forward.w1", # llama-pth "transformer.h.{bid}.mlp.w2", # qwen + "model.layers.layers.{bid}.mlp.gate_proj", # plamo ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -198,6 +206,7 @@ class TensorNameMap: "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "model.layers.layers.{bid}.mlp.down_proj", # plamo ), MODEL_TENSOR.FFN_DOWN_EXP: ( From 9d49236570b992eea99400bac7bbaaf69ca1ce2d Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 15:44:59 +0900 Subject: [PATCH 04/18] update norm --- llama.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index baf15863be0da..4075fc34984dd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5611,11 +5611,13 @@ struct llm_build_context { cur = inpL; - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm, - model.output_norm_b, - LLM_NORM, cb, -1); - cb(cur, "result_norm", -1); + // norm + { + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + } // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); From 4a3ef4f2a4327e0f0c59b2e8a37115c88f974207 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 17:44:29 +0900 Subject: [PATCH 05/18] able to compile --- llama.cpp | 202 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 144 insertions(+), 58 
deletions(-) diff --git a/llama.cpp b/llama.cpp index 4075fc34984dd..bbcfbd9950ea9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3705,15 +3705,15 @@ static void llm_load_tensors( layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); - layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); - layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_down); } } } break; @@ -5512,20 +5512,11 @@ struct llm_build_context { } struct ggml_cgraph * build_plamo() { - //TODO - /* struct ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * cur; struct ggml_tensor * inpL; - inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); - cb(inpL, "inp_embd", -1); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); @@ -5534,76 +5525,172 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(KQ_pos, "KQ_pos", -1); + // shift the entire K-cache if needed if (do_rope_shift) { - llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); } - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; + //for (int il = 0; il < n_layer; ++il) { + for (int il = 0; il < 1; ++il) { // norm cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, cb, il); - cb(cur, "attn_norm", il); + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attention_norm_0", il); + + struct ggml_tensor * attention_norm = cur; // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); + struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(tmpk, "tmpk", il); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + struct ggml_tensor * 
tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(tmpq, "tmpq", il); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + struct ggml_tensor * Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, + n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Kcur, "Kcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + struct ggml_tensor * Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, + n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Kcur, "Kcur", il); + // store key and value to memory + { + // compute the transposed [n_tokens, n_embd] V matrix - llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(tmpv, "tmpv", il); - cur = llm_build_kqv(ctx0, hparams, kv_self, - model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); - cb(cur, "kqv_out", il); - } + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); + cb(Vcur, "Vcur", il); - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); + //struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*(il*n_ctx + kv_head)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*kv_head); + cb(k, "k", il); + + /* + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + */ + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_gqa, + n_ctx*ggml_element_size(kv_self.v_l[il]), + kv_head*ggml_element_size(kv_self.v_l[il])); + cb(v, "v", il); + + // important: storing RoPE-ed version of K in the KV cache! 
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + cb(Q, "Q", il); + + /* + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + */ + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k_l[il], + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k_l[il])*n_embd_gqa, + ggml_element_size(kv_self.k_l[il])*n_embd_head, + 0); + cb(K, "K", il); + + // K * Q + //struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att + struct ggml_tensor * K_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K->ne[0], K->ne[1], Q->ne[2]); + cb(K_repeated, "K_repeated", il); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, ggml_repeat(ctx0, K, K_repeated), Q); + cb(KQ, "KQ", il); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_kv, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + cb(KQ_scaled, "KQ_scaled", il); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + cb(KQ_masked, "KQ_masked", il); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + cb(KQ_soft_max, "KQ_soft_max", il); + + // split cached V into n_head heads + /* + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + */ + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v_l[il], + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v_l[il])*n_ctx, + ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head, + 0); + cb(V, "V", il); + + //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att + struct ggml_tensor * V_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, V->ne[0], V->ne[1], Q->ne[2]); + cb(V_repeated, "V_repeated", il); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_repeat(ctx0, V, V_repeated), KQ_soft_max); + cb(KQV, "KQV", il); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + cb(KQV_merged, "KQV_merged", il); + + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + cb(cur, "KQV_merged_contiguous", il); + + // projection (no bias) + cur = ggml_mul_mat(ctx0, + model.layers[il].wo, + cur); + cb(cur, "result_wo", il); + } + struct ggml_tensor * sa_out = cur; + + cur = attention_norm; // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_gate, NULL, model.layers[il].ffn_down, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); + cb(cur, "mlp_out", il); } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); + cur = ggml_add(ctx0, cur, sa_out); + cb(cur, "mlp_out + 
sa_out", il); + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "mlp_out + sa_out + inpL", il); // input for next layer inpL = cur; @@ -5626,7 +5713,6 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); return gf; - */ } }; From a22040a810735a12881faca8a982d2656c1266a2 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 18:15:25 +0900 Subject: [PATCH 06/18] fix norm_rms_eps hparam --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index bbcfbd9950ea9..e0ff7e10cbe4b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2656,10 +2656,10 @@ static void llm_load_hparams( } break; case LLM_ARCH_PLAMO: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_13B; break; //TODO Check + case 40: model.type = e_model::MODEL_13B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; From 86d5348fd0ada7a404c76c92746c3a2cf66eb596 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 18:29:08 +0900 Subject: [PATCH 07/18] runnable --- llama.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama.cpp b/llama.cpp index e0ff7e10cbe4b..c3b34c2c36717 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5517,6 +5517,9 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); From f76fd39266abfffc840077d4129cdd91b51f29ca Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 21:53:04 +0900 Subject: [PATCH 08/18] use inp_pos --- llama.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index c3b34c2c36717..44b4b727c8e3e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5528,9 +5528,9 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(KQ_pos, "KQ_pos", -1); + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); // shift the entire K-cache if needed if (do_rope_shift) { @@ -5558,13 +5558,13 @@ struct llm_build_context { cb(tmpq, "tmpq", il); struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); From ca8f698638e1801b6f4f2bdf70e9f2bbacc107c9 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 17 Dec 2023 23:28:29 +0900 Subject: [PATCH 09/18] seems ok --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 44b4b727c8e3e..c54688d52dd2c 100644 --- a/llama.cpp 
+++ b/llama.cpp @@ -5537,8 +5537,8 @@ struct llm_build_context { llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); } - //for (int il = 0; il < n_layer; ++il) { - for (int il = 0; il < 1; ++il) { + for (int il = 0; il < n_layer; ++il) { + //for (int il = 0; il < 1; ++il) { // norm cur = llm_build_norm(ctx0, inpL, hparams, From febc63598b349076fba050147f4f3e6ebf11923a Mon Sep 17 00:00:00 2001 From: okada Date: Mon, 18 Dec 2023 00:16:56 +0900 Subject: [PATCH 10/18] update kqv code --- llama.cpp | 191 ++++++++++++++++++++++-------------------------------- 1 file changed, 79 insertions(+), 112 deletions(-) diff --git a/llama.cpp b/llama.cpp index c54688d52dd2c..4eef1317c8b22 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5520,6 +5520,10 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); cb(inpL, "inp_embd", -1); + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); cb(KQ_scale, "KQ_scale", -1); @@ -5528,10 +5532,6 @@ struct llm_build_context { struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - // shift the entire K-cache if needed if (do_rope_shift) { llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); @@ -5544,137 +5544,104 @@ struct llm_build_context { cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attention_norm_0", il); + cb(cur, "attention_norm", il); struct ggml_tensor * attention_norm = cur; // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(tmpk, "tmpk", il); + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(tmpq, "tmpq", il); + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); - struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - cb(Kcur, "Kcur", il); + cb(Qcur, "Qcur", il); - struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); - cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); - // store key and value to memory - { - // compute the transposed [n_tokens, n_embd] V matrix + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - cb(tmpv, 
"tmpv", il); + auto plamo_llm_build_kqv = []( + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_kv_cache & kv, + struct ggml_tensor * wo, + struct ggml_tensor * q_cur, + struct ggml_tensor * kq_mask, + int64_t n_ctx, + int32_t n_tokens, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); + cb(q, "q", il); + + struct ggml_tensor * k = + ggml_view_3d(ctx, kv.k_l[il], + n_embd_head, n_kv, n_head_kv, + ggml_row_size(kv.k_l[il]->type, n_embd_gqa), + ggml_row_size(kv.k_l[il]->type, n_embd_head), + 0); + cb(k, "k", il); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); - cb(Vcur, "Vcur", il); + // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att + struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]); + cb(k_repeated, "k_repeated", il); - //struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*(il*n_ctx + kv_head)); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k_l[il], n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k_l[il])*n_embd_gqa)*kv_head); - cb(k, "k", il); + struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q); + cb(kq, "kq", il); + + kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); + cb(kq, "kq_soft_max_ext", il); - /* - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); - */ - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v_l[il], n_tokens, n_embd_gqa, - n_ctx*ggml_element_size(kv_self.v_l[il]), - kv_head*ggml_element_size(kv_self.v_l[il])); + // split cached v into n_head heads + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv.v_l[il])*n_ctx, + ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head, + 0); cb(v, "v", il); - // important: storing RoPE-ed version of K in the KV cache! 
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } + // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att + struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]); + cb(k_repeated, "v_repeated", il); - struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - cb(Q, "Q", il); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq); + cb(kqv, "kqv", il); + + struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); + cb(kqv_merged, "kqv_merged", il); - /* - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); - */ - struct ggml_tensor * K = - ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k_l[il])*n_embd_gqa, - ggml_element_size(kv_self.k_l[il])*n_embd_head, - 0); - cb(K, "K", il); - - // K * Q - //struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att - struct ggml_tensor * K_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K->ne[0], K->ne[1], Q->ne[2]); - cb(K_repeated, "K_repeated", il); - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, ggml_repeat(ctx0, K, K_repeated), Q); - cb(KQ, "KQ", il); - - // KQ_scaled = KQ / sqrt(n_embd_head) - // KQ_scaled shape [n_kv, n_tokens, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - cb(KQ_scaled, "KQ_scaled", il); - - // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); - cb(KQ_masked, "KQ_masked", il); - - // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - cb(KQ_soft_max, "KQ_soft_max", il); - - // split cached V into n_head heads - /* - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v, - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); - */ - struct ggml_tensor * V = - ggml_view_3d(ctx0, kv_self.v_l[il], - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v_l[il])*n_ctx, - ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head, - 0); - cb(V, "V", il); - - //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att - struct ggml_tensor * V_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, V->ne[0], V->ne[1], Q->ne[2]); - cb(V_repeated, "V_repeated", il); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_repeat(ctx0, V, V_repeated), KQ_soft_max); - cb(KQV, "KQV", il); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - cb(KQV_merged, "KQV_merged", il); - - // cur = KQV_merged.contiguous().view(n_embd, n_tokens) - cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); - cb(cur, "KQV_merged_contiguous", il); - - // projection (no bias) - cur = ggml_mul_mat(ctx0, + struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens); + cb(cur, "kqv_merged_cont", il); + + cur = ggml_mul_mat(ctx, wo, cur); + return cur; + }; + + cur = plamo_llm_build_kqv(ctx0, hparams, kv_self, 
model.layers[il].wo, - cur); - cb(cur, "result_wo", il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, cb, il); + cb(cur, "kqv_out", il); } struct ggml_tensor * sa_out = cur; From 907b92185cc2cd2ef9832ade61faa322addabff2 Mon Sep 17 00:00:00 2001 From: okada Date: Mon, 18 Dec 2023 16:32:16 +0900 Subject: [PATCH 11/18] remove develop code --- llama.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 4eef1317c8b22..25a1a4a7c0567 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5538,7 +5538,6 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { - //for (int il = 0; il < 1; ++il) { // norm cur = llm_build_norm(ctx0, inpL, hparams, From 9339ffc96de5c33d297ba5045b86223eec2a5abb Mon Sep 17 00:00:00 2001 From: okada Date: Mon, 18 Dec 2023 16:46:51 +0900 Subject: [PATCH 12/18] update README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 01aef2afc36ae..ee152d5a042f2 100644 --- a/README.md +++ b/README.md @@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github - [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek) - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral) +- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557) **Multimodal models:** From db1b18dc9786f9fb17264878acbcd69ab35545ce Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 24 Dec 2023 17:58:55 +0900 Subject: [PATCH 13/18] shuffle attn_q.weight and attn_output.weight for broadcasting --- convert-hf-to-gguf.py | 20 ++++++++++++++++++++ llama.cpp | 6 ++++++ 2 files changed, 26 insertions(+) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6c783d79674bc..689285fd6be72 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1002,6 +1002,20 @@ def set_gguf_parameters(self): self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + def shuffle_attn_q_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(8, 5, 128, 5120) + data_torch = torch.permute(data_torch, (1, 0, 2, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + + def shuffle_attn_output_weight(self, data_torch): + assert data_torch.size() == (5120, 5120) + data_torch = data_torch.reshape(5120, 8, 5, 128) + data_torch = torch.permute(data_torch, (0, 2, 1, 3)) + data_torch = torch.reshape(data_torch, (5120, 5120)) + return data_torch + def write_tensors(self): block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) @@ -1016,6 +1030,12 @@ def write_tensors(self): print(f"Can not map tensor {name!r}") sys.exit() + # shuffle for broadcasting of gqa in ggml_mul_mat + if new_name.endswith("attn_q.weight"): + data_torch = self.shuffle_attn_q_weight(data_torch) + elif new_name.endswith("attn_output.weight"): + data_torch = self.shuffle_attn_output_weight(data_torch) + old_dtype = data_torch.dtype # convert any unsupported data types to float32 diff --git a/llama.cpp b/llama.cpp index 25a1a4a7c0567..3579c8960a825 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5601,11 +5601,14 @@ struct llm_build_context { 0); cb(k, "k", il); + /* // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att struct ggml_tensor * k_repeated 
= ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]); cb(k_repeated, "k_repeated", il); struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q); + */ + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); @@ -5620,11 +5623,14 @@ struct llm_build_context { 0); cb(v, "v", il); + /* // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]); cb(k_repeated, "v_repeated", il); struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq); + */ + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); cb(kqv, "kqv", il); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); From 26340a1902f8c1f8ce06c51a0e1a2d2a327fe67c Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 24 Dec 2023 18:14:12 +0900 Subject: [PATCH 14/18] remove plamo_llm_build_kqv and use llm_build_kqv --- llama.cpp | 76 +++---------------------------------------------------- 1 file changed, 3 insertions(+), 73 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3579c8960a825..53390ddbf90b8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5573,79 +5573,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - auto plamo_llm_build_kqv = []( - struct ggml_context * ctx, - const llama_hparams & hparams, - const llama_kv_cache & kv, - struct ggml_tensor * wo, - struct ggml_tensor * q_cur, - struct ggml_tensor * kq_mask, - int64_t n_ctx, - int32_t n_tokens, - int32_t n_kv, - const llm_build_cb & cb, - int il) { - const int64_t n_embd = hparams.n_embd; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3); - cb(q, "q", il); - - struct ggml_tensor * k = - ggml_view_3d(ctx, kv.k_l[il], - n_embd_head, n_kv, n_head_kv, - ggml_row_size(kv.k_l[il]->type, n_embd_gqa), - ggml_row_size(kv.k_l[il]->type, n_embd_head), - 0); - cb(k, "k", il); - - /* - // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att - struct ggml_tensor * k_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k->ne[0], k->ne[1], q->ne[2]); - cb(k_repeated, "k_repeated", il); - - struct ggml_tensor * kq = ggml_mul_mat(ctx, ggml_repeat(ctx, k, k_repeated), q); - */ - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - - kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head))); - cb(kq, "kq_soft_max_ext", il); - - // split cached v into n_head heads - struct ggml_tensor * v = - ggml_view_3d(ctx, kv.v_l[il], - n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv.v_l[il])*n_ctx, - ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head, - 0); - cb(v, "v", il); - - /* - // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att - struct ggml_tensor * v_repeated = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, v->ne[0], v->ne[1], q->ne[2]); - cb(k_repeated, "v_repeated", il); - - struct ggml_tensor * kqv = ggml_mul_mat(ctx, ggml_repeat(ctx, v, v_repeated), kq); - */ - struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); - cb(kqv, "kqv", il); - - struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); - cb(kqv_merged, "kqv_merged", il); - - 
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens); - cb(cur, "kqv_merged_cont", il); - - cur = ggml_mul_mat(ctx, wo, cur); - return cur; - }; - - cur = plamo_llm_build_kqv(ctx0, hparams, kv_self, - model.layers[il].wo, - Qcur, KQ_mask, n_ctx, n_tokens, n_kv, cb, il); + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 0.0f, cb, il); cb(cur, "kqv_out", il); } struct ggml_tensor * sa_out = cur; From 700f7c600a47cca3971acdeff6edf2554121a438 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 24 Dec 2023 18:15:20 +0900 Subject: [PATCH 15/18] fix style --- llama.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/llama.cpp b/llama.cpp index 53390ddbf90b8..158422988a2f6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5543,7 +5543,7 @@ struct llm_build_context { cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "attention_norm", il); + cb(cur, "attn_norm", il); struct ggml_tensor * attention_norm = cur; @@ -5603,13 +5603,10 @@ struct llm_build_context { cur = inpL; - // norm - { - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm, NULL, - LLM_NORM_RMS, cb, -1); - cb(cur, "result_norm", -1); - } + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); From 307481f28da0150e1e5ba793ae71a4249cad5086 Mon Sep 17 00:00:00 2001 From: okada Date: Sun, 24 Dec 2023 18:57:53 +0900 Subject: [PATCH 16/18] update --- llama.cpp | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/llama.cpp b/llama.cpp index 32305c338a1bf..5e05b41267344 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3678,13 +3678,6 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); - - if (backend_norm == GGML_BACKEND_GPU) { - vram_weights += ggml_nbytes(model.output_norm); - } - if (backend_output == GGML_BACKEND_GPU_SPLIT) { - vram_weights += ggml_nbytes(model.output); - } } const uint32_t n_ff = hparams.n_ff; @@ -3708,14 +3701,7 @@ static bool llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); - layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); - - if (backend == GGML_BACKEND_GPU) { - vram_weights += - ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + - ggml_nbytes(layer.ffn_gate) + ggml_nbytes(layer.ffn_up) + ggml_nbytes(layer.ffn_down); - } + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); } } break; default: @@ -5706,9 +5692,9 @@ struct llm_build_context { llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - cur = llm_build_kqv(ctx0, hparams, kv_self, + cur = llm_build_kqv(ctx0, model, hparams, kv_self, model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 0.0f, cb, il); + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); cb(cur, "kqv_out", 
il); } struct ggml_tensor * sa_out = cur; From eedd43457583769aa5a31c9ab1c01f3996f2b3fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Dec 2023 15:30:12 +0200 Subject: [PATCH 17/18] llama : remove obsolete KQ_scale --- llama.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 5e05b41267344..90dc1b11ff692 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5643,10 +5643,6 @@ struct llm_build_context { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); cb(inp_pos, "inp_pos", -1); - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); cb(KQ_mask, "KQ_mask", -1); From 1949c95598536d6e7cdcb3e9ed99862c191dc9a0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 24 Dec 2023 15:33:31 +0200 Subject: [PATCH 18/18] plamo : fix tensor names for correct GPU offload --- llama.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index 90dc1b11ff692..03d1a4b667e0a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5704,13 +5704,14 @@ struct llm_build_context { model.layers[il].ffn_gate, NULL, model.layers[il].ffn_down, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "mlp_out", il); + cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, sa_out); - cb(cur, "mlp_out + sa_out", il); + cb(cur, "l_out", il); + cur = ggml_add(ctx0, cur, inpL); - cb(cur, "mlp_out + sa_out + inpL", il); + cb(cur, "l_out", il); // input for next layer inpL = cur;
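
A short illustration of the head shuffle introduced in PATCH 13/18 ("shuffle attn_q.weight and attn_output.weight for broadcasting"). As the converter comment says, the point is to reorder the rows of attn_q.weight and the columns of attn_output.weight with one and the same permutation, so that grouped-query attention can use a plain broadcasting ggml_mul_mat and the ggml_repeat workaround in build_plamo can be dropped (which the llama.cpp half of that patch does). The sketch below is not part of the patch series; it re-derives the permutation explicitly and checks that the two shuffles cancel in the composed projection. The sizes are scaled down from the hard-coded PLaMo-13B values (8 groups x 5 KV heads x head dim 128 = 5120) so it runs instantly, and every name outside the two shuffle functions is illustrative only.

    # Scaled-down sketch of the PATCH 13 shuffle (convert-hf-to-gguf.py hard-codes 8, 5, 128, 5120).
    import torch

    n_groups, n_kv_heads, head_dim = 8, 5, 4
    hidden = n_groups * n_kv_heads * head_dim

    def shuffle_attn_q_weight(w: torch.Tensor) -> torch.Tensor:
        # regroup the query-projection rows so heads sharing a KV head become contiguous
        w = w.reshape(n_groups, n_kv_heads, head_dim, hidden)
        w = torch.permute(w, (1, 0, 2, 3))
        return w.reshape(hidden, hidden)

    def shuffle_attn_output_weight(w: torch.Tensor) -> torch.Tensor:
        # apply the matching permutation to the output-projection columns
        w = w.reshape(hidden, n_groups, n_kv_heads, head_dim)
        w = torch.permute(w, (0, 2, 1, 3))
        return w.reshape(hidden, hidden)

    wq = torch.randn(hidden, hidden, dtype=torch.float64)
    wo = torch.randn(hidden, hidden, dtype=torch.float64)

    # both shuffles are the same index permutation, applied to the rows of wq and the columns of wo
    perm = torch.arange(hidden).reshape(n_groups, n_kv_heads, head_dim).permute(1, 0, 2).reshape(-1)
    assert torch.equal(shuffle_attn_q_weight(wq), wq[perm, :])
    assert torch.equal(shuffle_attn_output_weight(wo), wo[:, perm])

    # the row and column permutations cancel, so the composed projection is unchanged
    assert torch.allclose(shuffle_attn_output_weight(wo) @ shuffle_attn_q_weight(wq), wo @ wq)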