From 69ff61397d2b7b550dcdda4a35b35128892408b0 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Thu, 14 Mar 2024 17:21:56 +0100 Subject: [PATCH] llama : support models without vocabulary (#5798) * additional methods to read model and ctx parameters * vocab size as a part of a model metadata * models without vocabulary, convert.py part * models without vocabulary, llama.cpp part * PR clean up * converter scrypt fixes * llama_vocab_type update (renamed the new key) * pr review fixes * revert function renaming * one more NoVocab assert --- convert.py | 126 +++++++++++++++++++++--------------- gguf-py/gguf/constants.py | 2 + gguf-py/gguf/gguf_writer.py | 3 + llama.cpp | 92 +++++++++++++++++--------- llama.h | 7 +- 5 files changed, 142 insertions(+), 88 deletions(-) diff --git a/convert.py b/convert.py index c15f8c47ea4f7..161430f3e717e 100755 --- a/convert.py +++ b/convert.py @@ -332,6 +332,9 @@ def load(model_plus: ModelPlus) -> Params: # class BpeVocab: + tokenizer_model = "gpt2" + name = "bpe" + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) if isinstance(self.bpe_tokenizer.get('model'), dict): @@ -390,6 +393,9 @@ def __repr__(self) -> str: class SentencePieceVocab: + tokenizer_model = "llama" + name = "spm" + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) added_tokens: dict[str, int] @@ -453,6 +459,9 @@ def __repr__(self) -> str: class HfVocab: + tokenizer_model = "llama" + name = "hfft" + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None: try: from transformers import AutoTokenizer @@ -553,7 +562,15 @@ def __repr__(self) -> str: return f"" -Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab" +class NoVocab: + tokenizer_model = "no_vocab" + name = "no_vocab" + + def __repr__(self) -> str: + return "" + + +Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab" # @@ -935,8 +952,10 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N # Handle special case where the model's vocab size is not set if params.n_vocab == -1: raise ValueError( - f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?" + f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}" ) + if isinstance(vocab, NoVocab): + return # model has no vocab # Check for a vocab size mismatch if params.n_vocab == vocab.vocab_size: @@ -977,6 +996,7 @@ def add_meta_arch(self, params: Params) -> None: name = str(params.path_model.parent).split('/')[-1] self.gguf.add_name (name) + self.gguf.add_vocab_size (params.n_vocab) self.gguf.add_context_length (params.n_ctx) self.gguf.add_embedding_length (params.n_embd) self.gguf.add_block_count (params.n_layer) @@ -1013,21 +1033,9 @@ def add_meta_arch(self, params: Params) -> None: if params.ftype is not None: self.gguf.add_file_type(params.ftype) - def handle_tokenizer_model(self, vocab: Vocab) -> str: - # Map the vocab types to the supported tokenizer models - tokenizer_model = { - SentencePieceVocab: "llama", - HfVocab: "llama", - BpeVocab: "gpt2", - }.get(type(vocab)) - - # Block if vocab type is not predefined - if tokenizer_model is None: - raise ValueError("Unknown vocab type: Not supported") - - return tokenizer_model - def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + assert not isinstance(vocab, NoVocab) + tokens = [] scores = [] toktypes = [] @@ -1043,11 +1051,8 @@ def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list return tokens, scores, toktypes def add_meta_vocab(self, vocab: Vocab) -> None: - # Handle the tokenizer model - tokenizer_model = self.handle_tokenizer_model(vocab) - # Ensure that tokenizer_model is added to the GGUF model - self.gguf.add_tokenizer_model(tokenizer_model) + self.gguf.add_tokenizer_model(vocab.tokenizer_model) # Extract model vocabulary for model conversion tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) @@ -1074,6 +1079,26 @@ def write_meta(self) -> None: def write_tensor_info(self) -> None: self.gguf.write_ti_data_to_file() + def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + use_processpool_executor=True, + ) + else: + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + + start = time.time() + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + elapsed = time.time() - start + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + print( + f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) + self.gguf.write_tensor_data(ndarray) + def close(self) -> None: self.gguf.close() @@ -1082,7 +1107,7 @@ def write_vocab_only( fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - check_vocab_size(params, vocab, pad_vocab = pad_vocab) + check_vocab_size(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -1120,8 +1145,11 @@ def write_all( # meta data of.add_meta_arch(params) - of.add_meta_vocab(vocab) - of.add_meta_special_vocab(svocab) + if isinstance(vocab, NoVocab): + of.gguf.add_tokenizer_model(vocab.tokenizer_model) + else: + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) # tensor info for name, lazy_tensor in model.items(): @@ -1131,24 +1159,7 @@ def write_all( of.write_tensor_info() # tensor data - ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) - if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map( - OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, - use_processpool_executor=True, - ) - else: - ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) - - start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): - elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) - padi = len(str(len(model))) - print( - f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" - ) - of.gguf.write_tensor_data(ndarray) + of.write_tensor_data(ftype, model, concurrency) of.close() @@ -1309,8 +1320,8 @@ def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]: return vtype, path raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}") - def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab: - load_merges = vocabtype == "bpe" + def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab: + load_merges = vocab.name == "bpe" n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None return gguf.SpecialVocab( model_parent_path, @@ -1319,30 +1330,34 @@ def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: n_vocab=n_vocab, ) - def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: + def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: vocab_type, path = self._select_file(vocab_types) print(f"Loading vocab file {path!r}, type {vocab_type!r}") added_tokens_path = path.parent / "added_tokens.json" - vocab: Vocab if vocab_type == "bpe": - vocab = BpeVocab( + return BpeVocab( path, added_tokens_path if added_tokens_path.exists() else None ) - elif vocab_type == "spm": - vocab = SentencePieceVocab( + if vocab_type == "spm": + return SentencePieceVocab( path, added_tokens_path if added_tokens_path.exists() else None ) - elif vocab_type == "hfft": - vocab = HfVocab( + if vocab_type == "hfft": + return HfVocab( path.parent, added_tokens_path if added_tokens_path.exists() else None ) + raise ValueError(vocab_type) + + def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: + vocab: Vocab + if len(vocab_types) == 1 and "no_vocab" in vocab_types: + vocab = NoVocab() else: - raise ValueError(vocab_type) + vocab = self._create_vocab_by_path(vocab_types) # FIXME: Respect --vocab-dir? special_vocab = self._create_special_vocab( vocab, - vocab_type, model_parent_path, ) return vocab, special_vocab @@ -1380,6 +1395,7 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab") parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") @@ -1392,6 +1408,10 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") args = parser.parse_args(args_in) + if args.no_vocab: + if args.vocab_only: + raise ValueError("no need to specify --vocab-only if using --no-vocab") + args.vocab_type = "no_vocab" if args.dump_single: model_plus = lazy_load_file(args.model) @@ -1442,7 +1462,7 @@ def main(args_in: list[str] | None = None) -> None: print(f"Wrote {outfile}") return - if model_plus.vocab is not None and args.vocab_dir is None: + if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab: vocab = model_plus.vocab print(f"Vocab info: {vocab}") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 99f71f0a1aa29..2d7cf16c14ed1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -32,6 +32,7 @@ class General: FILE_TYPE = "general.file_type" class LLM: + VOCAB_SIZE = "{arch}.vocab_size" CONTEXT_LENGTH = "{arch}.context_length" EMBEDDING_LENGTH = "{arch}.embedding_length" BLOCK_COUNT = "{arch}.block_count" @@ -752,6 +753,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE # LLM +KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 4d389be951d72..81b2eb884d485 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -321,6 +321,9 @@ def add_custom_alignment(self, alignment: int) -> None: self.data_alignment = alignment self.add_uint32(Keys.General.ALIGNMENT, alignment) + def add_vocab_size(self, size: int) -> None: + self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size) + def add_context_length(self, length: int) -> None: self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length) diff --git a/llama.cpp b/llama.cpp index eb48b1e90c8d1..10fd53469eb6f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -258,6 +258,7 @@ enum llm_kv { LLM_KV_GENERAL_SOURCE_URL, LLM_KV_GENERAL_SOURCE_HF_REPO, + LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, @@ -321,6 +322,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, @@ -3242,10 +3244,11 @@ static const char * llama_model_type_name(e_model type) { static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ switch (type) { - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - default: return "unknown"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + default: return "unknown"; } } @@ -3277,14 +3280,14 @@ static void llm_load_hparams( ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); // get hparams kv - ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); - ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); - ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); - ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); - ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); - ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false); - ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); @@ -3645,30 +3648,25 @@ static void llm_load_vocab( const auto kv = LLM_KV(model.arch); - const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); - } - - const float * scores = nullptr; - const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx != -1) { - scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - } - - const int * toktypes = nullptr; - const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx != -1) { - toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); - } - // determine vocab type { std::string tokenizer_name; ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); - if (tokenizer_name == "llama") { + if (tokenizer_name == "no_vocab") { + vocab.type = LLAMA_VOCAB_TYPE_NONE; + + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.linefeed_id = -1; + + return; + } else if (tokenizer_name == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens @@ -3734,6 +3732,23 @@ static void llm_load_vocab( } } + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + } + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); vocab.id_to_token.resize(n_vocab); @@ -5023,7 +5038,8 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam llm_load_print_meta(ml, model); - if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { + if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && + model.hparams.n_vocab != model.vocab.id_to_token.size()) { throw std::runtime_error("vocab size mismatch"); } @@ -9361,26 +9377,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { } static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; } static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; } static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; } static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; } static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; } static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { @@ -9401,6 +9423,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: { @@ -10232,6 +10255,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } } break; + case LLAMA_VOCAB_TYPE_NONE: + GGML_ASSERT(false); } return output; @@ -13138,7 +13163,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { } int32_t llama_n_vocab(const struct llama_model * model) { - return model->vocab.id_to_token.size(); + return model->hparams.n_vocab; } int32_t llama_n_ctx_train(const struct llama_model * model) { @@ -13962,14 +13987,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id } const char * llama_token_get_text(const struct llama_model * model, llama_token token) { + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); return model->vocab.id_to_token[token].text.c_str(); } float llama_token_get_score(const struct llama_model * model, llama_token token) { + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); return model->vocab.id_to_token[token].score; } llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) { + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); return model->vocab.id_to_token[token].type; } diff --git a/llama.h b/llama.h index 2d16cc9b9fa2c..90aa5372e740b 100644 --- a/llama.h +++ b/llama.h @@ -59,9 +59,10 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece - LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding - LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 2, // Byte Pair Encoding + LLAMA_VOCAB_TYPE_WPM = 3, // WordPiece }; // note: these values should be synchronized with ggml_rope