diff --git a/README.md b/README.md
index c90159c3..dbfeb8ba 100644
--- a/README.md
+++ b/README.md
@@ -171,4 +171,7 @@ the converter to better facilitate scripted jobs.
 
 **2023-09-27**: Prebuilt wheels are now available, credit to [@jllllll](https://github.com/jllllll). They're on the
 [releases page here](https://github.com/turboderp/exllamav2/releases). A solution to installing prebuilt wheels straight
-from PyPI is still pending. Updated installation instructions above.
\ No newline at end of file
+from PyPI is still pending. Updated installation instructions above.
+
+**2023-10-03**: Added support for extended vocabularies and alternative BOS/EOS/UNK tokens and the ability to
+encode/decode sequences with special tokens. Added Orca template to the chatbot example.
\ No newline at end of file
diff --git a/examples/chat.py b/examples/chat.py
index 72292d51..622b4873 100644
--- a/examples/chat.py
+++ b/examples/chat.py
@@ -19,11 +19,14 @@
 )
 
 from chat_formatting import CodeBlockFormatter
+from chat_prompts import prompt_formats
+prompt_formats_list = list(prompt_formats.keys())
 
 # Options
 
 parser = argparse.ArgumentParser(description = "Simple Llama2 chat example for ExLlamaV2")
-parser.add_argument("-mode", "--mode", choices = ["llama", "raw", "codellama"], help = "Chat mode. Use llama for Llama 1/2 chat finetunes.")
+parser.add_argument("-modes", "--modes", action = "store_true", help = "List available modes and exit.")
+parser.add_argument("-mode", "--mode", choices = prompt_formats_list, help = "Chat mode. Use llama for Llama 1/2 chat finetunes.")
 parser.add_argument("-un", "--username", type = str, default = "User", help = "Username when using raw chat mode")
 parser.add_argument("-bn", "--botname", type = str, default = "Chatbort", help = "Bot name when using raw chat mode")
 parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
@@ -37,86 +40,61 @@
 parser.add_argument("-resc", "--response_chunk", type = int, default = 250, help = "Space to reserve in context for reply, default = 250")
 parser.add_argument("-ncf", "--no_code_formatting", action = "store_true", help = "Disable code formatting/syntax highlighting")
 
-# Initialize model and tokenizer
+# Arrrgs
 
 model_init.add_args(parser)
 args = parser.parse_args()
-model_init.check_args(args)
-model_init.print_options(args)
-model, tokenizer = model_init.init(args)
 
-# Create cache
-
-cache = ExLlamaV2Cache(model)
+# Prompt templates/modes
 
-# Prompt templates
+if args.modes:
+    print(" -- Available formats:")
+    for k, v in prompt_formats.items():
+        print(f" -- {k:12} : {v().description}")
+    sys.exit()
 
 username = args.username
 botname = args.botname
 system_prompt = args.system_prompt
 
-mode = args.mode
-
-if mode == "llama" or mode == "codellama":
-
-    if not system_prompt:
-
-        if mode == "llama":
-
-            system_prompt = \
-            """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """ + \
-            """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """ + \
-            """Please ensure that your responses are socially unbiased and positive in nature."""
-
-        elif mode == "codellama":
-
-            system_prompt = \
-            """You are a helpful coding assistant. Always answer as helpfully as possible."""
-
-    first_prompt = \
-    """[INST] <<SYS>>\n<|system_prompt|>\n<</SYS>>\n\n<|user_prompt|> [/INST]"""
-
-    subs_prompt = \
-    """[INST] <|user_prompt|> [/INST]"""
-
-elif mode == "raw":
+if args.mode is None:
+    print(" ## Error: No mode specified.")
+    sys.exit()
 
-    if not system_prompt:
+prompt_format = prompt_formats[args.mode]()
+prompt_format.botname = botname
+prompt_format.username = username
+if system_prompt is None: system_prompt = prompt_format.default_system_prompt()
 
-        system_prompt = \
-        f"""This is a conversation between a helpful AI assistant named {botname} and a """ + ("""user named {username}.""" if username != "User" else """user.""")
+# Initialize model and tokenizer
 
-    first_prompt = \
-    f"""<|system_prompt|>\n{username}: <|user_prompt|>\n{botname}:"""
+model_init.check_args(args)
+model_init.print_options(args)
+model, tokenizer = model_init.init(args)
 
-    subs_prompt = \
-    f"""{username}: <|user_prompt|>\n{botname}:"""
+# Create cache
 
-else:
+cache = ExLlamaV2Cache(model)
 
-    print(" ## Error: Incorrect/no mode specified.")
-    sys.exit()
 
 # Chat context
 
 def format_prompt(user_prompt, first):
-    global system_prompt, first_prompt, subs_prompt
+    global system_prompt, prompt_format
 
     if first:
-        return first_prompt \
+        return prompt_format.first_prompt() \
            .replace("<|system_prompt|>", system_prompt) \
            .replace("<|user_prompt|>", user_prompt)
     else:
-        return subs_prompt \
+        return prompt_format.subs_prompt() \
            .replace("<|user_prompt|>", user_prompt)
 
 def encode_prompt(text):
-    global tokenizer, mode
+    global tokenizer, prompt_format
 
-    if mode == "llama" or mode == "codellama":
-        return tokenizer.encode(text, add_bos = True)
-
-    if mode == "raw":
-        return tokenizer.encode(text)
+    add_bos, add_eos, encode_special_tokens = prompt_format.encoding_options()
+    return tokenizer.encode(text, add_bos = add_bos, add_eos = add_eos, encode_special_tokens = encode_special_tokens)
 
 user_prompts = []
 responses_ids = []
@@ -130,7 +108,8 @@ def get_tokenized_context(max_len):
 
     for turn in range(len(user_prompts)):
 
-        up_ids = encode_prompt(format_prompt(user_prompts[turn], context.shape[-1] == 0))
+        up_text = format_prompt(user_prompts[turn], context.shape[-1] == 0)
+        up_ids = encode_prompt(up_text)
         context = torch.cat([context, up_ids], dim=-1)
 
         if turn < len(responses_ids):
@@ -161,13 +140,7 @@ def get_tokenized_context(max_len):
 
 # Stop conditions
 
-if mode == "llama" or mode == "codellama":
-
-    generator.set_stop_conditions([tokenizer.eos_token_id])
-
-if mode == "raw":
-
-    generator.set_stop_conditions([username + ":", username[0:1] + ":", username.upper() + ":", username.lower() + ":", tokenizer.eos_token_id])
+generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer))
 
 
 # ANSI color codes
@@ -175,6 +148,7 @@ def get_tokenized_context(max_len):
 col_user = "\u001b[33;1m" # Yellow
 col_bot = "\u001b[34;1m" # Blue
 col_error = "\u001b[31;1m" # Magenta
+col_sysprompt = "\u001b[37;1m" # Grey
 
 # Code block formatting
 
@@ -188,6 +162,11 @@ def get_tokenized_context(max_len):
 
 # Main loop
 
+print(f" -- Prompt format: {args.mode}")
+print(f" -- System prompt:")
+print()
+print(col_sysprompt + system_prompt.strip() + col_default)
+
 while True:
 
     # Get user prompt
@@ -207,7 +186,7 @@ def get_tokenized_context(max_len):
 
     # Stream response
 
-    if mode == "raw":
+    if prompt_format.print_bot_name():
 
         print(col_bot + botname + ": " + col_default, end = "")
 
@@ -288,7 +267,7 @@ def get_tokenized_context(max_len):
 
     if eos:
 
-        if mode == "llama" or mode == "codellama":
+        if prompt_format.print_extra_newline():
             print()
         break
 
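The hard-coded `llama`/`codellama`/`raw` branches above are replaced by the `prompt_formats` registry defined in `examples/chat_prompts.py` (next file). As a rough sketch of the extension point this creates, a new template only needs a small subclass appended to `chat_prompts.py`; the `vicuna` name and `PromptFormat_vicuna` class below are hypothetical illustrations, not part of this patch:

```python
# Hypothetical addition to examples/chat_prompts.py; not part of this patch.
# Shows how the PromptFormat interface is meant to be extended.

class PromptFormat_vicuna(PromptFormat):

    description = "Vicuna-style instruct format (illustrative only)"

    def default_system_prompt(self):
        return "A chat between a curious user and a helpful AI assistant."

    def first_prompt(self):
        # <|system_prompt|> and <|user_prompt|> are substituted by chat.py's format_prompt()
        return "<|system_prompt|>\n\nUSER: <|user_prompt|>\nASSISTANT:"

    def subs_prompt(self):
        return "USER: <|user_prompt|>\nASSISTANT:"

    def stop_conditions(self, tokenizer):
        return [tokenizer.eos_token_id, "USER:"]

    def encoding_options(self):  # (add_bos, add_eos, encode_special_tokens)
        return True, False, False

    def print_extra_newline(self):
        return True


prompt_formats["vicuna"] = PromptFormat_vicuna
```

With that in place, `python chat.py -modes` would list the new format and `-mode vicuna` would select it; `chat.py` pulls the system prompt, turn templates, stop conditions and encoding flags from the instance without further changes.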
diff --git a/examples/chat_prompts.py b/examples/chat_prompts.py
new file mode 100644
index 00000000..606e6889
--- /dev/null
+++ b/examples/chat_prompts.py
@@ -0,0 +1,177 @@
+
+class PromptFormat:
+
+    botname = "Chatbort"
+    username = "User"
+
+    def __init__(self):
+        pass
+
+    #
+
+    def default_system_prompt(self):
+        raise NotImplementedError
+
+    def first_prompt(self):
+        raise NotImplementedError
+
+    def subs_prompt(self):
+        raise NotImplementedError
+
+    def stop_conditions(self, tokenizer):
+        raise NotImplementedError
+
+    def encoding_options(self):  # (add_bos, add_eos, encode_special_tokens)
+        raise NotImplementedError
+
+    def print_bot_name(self):
+        return False
+
+    def print_extra_newline(self):
+        return False
+
+
+class PromptFormat_raw(PromptFormat):
+
+    description = "Model-agnostic mode simulating a raw chatlog"
+
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def default_system_prompt(self):
+        return \
+            f"""This is a conversation between a helpful AI assistant named {self.botname} and a """ + \
+            (f"""user named {self.username}.""" if self.username != "User" else """user.""")
+
+    def first_prompt(self):
+        return \
+            f"""<|system_prompt|>\n{self.username}: <|user_prompt|>\n{self.botname}:"""
+
+    def subs_prompt(self):
+        return \
+            f"""{self.username}: <|user_prompt|>\n{self.botname}:"""
+
+    def stop_conditions(self, tokenizer):
+        return \
+            [self.username + ":",
+             self.username[0:1] + ":",
+             self.username.upper() + ":",
+             self.username.lower() + ":",
+             tokenizer.eos_token_id]
+
+    def encoding_options(self):
+        return False, False, False
+
+    def print_bot_name(self):
+        return True
+
+
+class PromptFormat_llama(PromptFormat):
+
+    description = "Llama-chat, Llama2-chat and Mistral-instruct models"
+
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def default_system_prompt(self):
+        return \
+            """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """ + \
+            """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """ + \
+            """Please ensure that your responses are socially unbiased and positive in nature."""
+
+    def first_prompt(self):
+        return \
+            """[INST] <<SYS>>\n<|system_prompt|>\n<</SYS>>\n\n<|user_prompt|> [/INST]"""
+
+    def subs_prompt(self):
+        return \
+            """[INST] <|user_prompt|> [/INST]"""
+
+    def stop_conditions(self, tokenizer):
+        return \
+            [tokenizer.eos_token_id]
+
+    def encoding_options(self):
+        return True, False, False
+
+    def print_extra_newline(self):
+        return True
+
+
+class PromptFormat_codellama(PromptFormat_llama):
+
+    description = "CodeLlama-instruct"
+
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def default_system_prompt(self):
+        return \
+            """You are a helpful coding assistant. Always answer as helpfully as possible."""
+
+
+class PromptFormat_chatml(PromptFormat):
+
+    description = "ChatML format, as used by e.g. (Mistral)Orca"
+
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def default_system_prompt(self):
+        return \
+            f"""You are {self.botname}, a large language model. Answer as concisely as possible."""
+
+    def first_prompt(self):
+        return \
+            """<|im_start|>system\n""" + \
+            """<|system_prompt|>\n""" + \
+            """<|im_end|>\n""" + \
+            """<|im_start|>user\n""" + \
+            """<|user_prompt|><|im_end|>\n""" + \
+            """<|im_start|>assistant\n"""
+
+    def subs_prompt(self):
+        return \
+            """<|im_end|>\n""" + \
+            """<|im_start|>user\n""" + \
+            """<|user_prompt|><|im_end|>\n""" + \
+            """<|im_start|>assistant\n"""
+
+    def stop_conditions(self, tokenizer):
+        return \
+            [tokenizer.eos_token_id,
+             """<|im_end|>"""]
+
+    def encoding_options(self):
+        return False, False, True
+
+    def print_extra_newline(self):
+        return True
+
+
+class PromptFormat_tinyllama(PromptFormat_chatml):
+
+    description = "ChatML format, but ignoring special/added tokens. Use for TinyLlama-chat v0.3"
+
+    def encoding_options(self):
+        return False, False, False
+
+
+prompt_formats = \
+{
+    "raw": PromptFormat_raw,
+    "llama": PromptFormat_llama,
+    "codellama": PromptFormat_codellama,
+    "chatml": PromptFormat_chatml,
+    "tinyllama": PromptFormat_tinyllama,
+}
+
+
+
+
+
+
diff --git a/exllamav2/generator/base.py b/exllamav2/generator/base.py
index 34940976..f4ca08fb 100644
--- a/exllamav2/generator/base.py
+++ b/exllamav2/generator/base.py
@@ -47,7 +47,9 @@ def generate_simple(self, prompt: str or list,
                         gen_settings: ExLlamaV2Sampler.Settings,
                         num_tokens: int,
                         seed = None,
-                        token_healing = False):
+                        token_healing = False,
+                        encode_special_tokens = False,
+                        decode_special_tokens = False ):
 
         # Apply seed
 
@@ -56,7 +58,7 @@ def generate_simple(self, prompt: str or list,
         # Tokenize input and produce padding mask if needed
 
         batch_size = 1 if isinstance(prompt, str) else len(prompt)
-        ids = self.tokenizer.encode(prompt)
+        ids = self.tokenizer.encode(prompt, encode_special_tokens = encode_special_tokens)
 
         overflow = ids.shape[-1] + num_tokens - self.model.config.max_seq_len
         if overflow > 0: ids = ids[:, overflow:]
@@ -93,7 +95,7 @@ def generate_simple(self, prompt: str or list,
 
         # Decode
 
-        text = self.tokenizer.decode(self.sequence_ids)
+        text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens = decode_special_tokens)
 
         if isinstance(prompt, str): return text[0]
         return text
diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index 37b13bd9..c9020162 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -38,7 +38,8 @@ def clone(self):
 
         def disallow_tokens(self, tokenizer, tokens):
 
             if self.token_bias is None:
-                self.token_bias = torch.zeros((tokenizer.config.vocab_size,), dtype = torch.float)
+                padding = -tokenizer.config.vocab_size % 32
+                self.token_bias = torch.zeros((tokenizer.config.vocab_size + padding,), dtype = torch.float)
 
             self.token_bias[tokens] = float("-inf")
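`generate_simple()` now forwards `encode_special_tokens`/`decode_special_tokens` to the tokenizer, and `disallow_tokens()` pads its bias tensor up to the next multiple of 32, presumably so it lines up with logits that are padded the same way when the vocabulary is extended. A minimal usage sketch follows; the model path is a placeholder, and the surrounding loading code follows the usual ExLlamaV2 example scripts rather than anything introduced by this patch:

```python
# Sketch only: exercising the new generate_simple() flags.
# "/path/to/model" is a placeholder; the ChatML markup assumes the model's
# added_tokens.json defines <|im_start|> / <|im_end|>.

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"
config.prepare()

model = ExLlamaV2(config)
model.load()
tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.8
settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])  # bias tensor is now padded to a multiple of 32

prompt = "<|im_start|>user\nWrite a haiku about tokenizers.<|im_end|>\n<|im_start|>assistant\n"

# encode_special_tokens = True maps the ChatML markup to its added-token IDs
# instead of tokenizing it as plain text; decode_special_tokens stays False,
# so special tokens are stripped from the returned string.
output = generator.generate_simple(prompt, settings, num_tokens = 120,
                                   encode_special_tokens = True)
print(output)
```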
diff --git a/exllamav2/tokenizer.py b/exllamav2/tokenizer.py
index 84e3588c..219bd69f 100644
--- a/exllamav2/tokenizer.py
+++ b/exllamav2/tokenizer.py
@@ -1,6 +1,7 @@
 from exllamav2.config import ExLlamaV2Config
 from sentencepiece import SentencePieceProcessor
 import torch
+import os, json, re
 
 
 class ExLlamaV2Tokenizer:
@@ -35,6 +36,10 @@ def __init__(self, children = None, leaf = None):
 
     char_trie: Trie = None
     char_trie_ci: Trie = None
 
+    extended_id_to_piece = {}
+    extended_piece_to_id = {}
+    special_delimiters = None
+
     tokenized_str_cache = {}
     max_cached_strings = 100
@@ -44,11 +49,45 @@ def __init__(self, config, lazy_init = False):
 
         self.tokenizer = SentencePieceProcessor(model_file = self.config.tokenizer_path)
 
+        # Load added_tokens.json if present
+
+        added_tokens_path = os.path.join(self.config.model_dir, "added_tokens.json")
+        if os.path.exists(added_tokens_path):
+            with open(added_tokens_path) as f:
+                self.extended_piece_to_id = json.load(f)
+
+        self.extended_id_to_piece = { v: k for k, v in self.extended_piece_to_id.items() }
+
+        # Get control token IDs
+
+        # self.eos_token_id = self.tokenizer.eos_id()
+        # self.bos_token_id = self.tokenizer.bos_id()
+        # self.unk_token_id = config.unk_token_id
+
         self.unk_token_id = self.tokenizer.unk_id()
-        self.eos_token_id = self.tokenizer.eos_id()
-        self.bos_token_id = self.tokenizer.bos_id()
+        self.eos_token_id = config.eos_token_id
+        self.bos_token_id = config.bos_token_id
+
+        # Get control token strings
+
+        try: self.unk_token = self.extended_id_to_piece[self.unk_token_id] or self.tokenizer.id_to_piece(self.unk_token_id)
+        except: pass
+        try: self.bos_token = self.extended_id_to_piece[self.bos_token_id] or self.tokenizer.id_to_piece(self.bos_token_id)
+        except: pass
+        try: self.eos_token = self.extended_id_to_piece[self.eos_token_id] or self.tokenizer.id_to_piece(self.eos_token_id)
+        except: pass
+
         self.pad_token_id = 0
 
+        # Make sure extended vocab contains control tokens
+
+        self.extended_id_to_piece[self.unk_token_id] = self.unk_token
+        self.extended_id_to_piece[self.bos_token_id] = self.bos_token
+        self.extended_id_to_piece[self.eos_token_id] = self.eos_token
+        self.extended_piece_to_id[self.unk_token] = self.unk_token_id
+        self.extended_piece_to_id[self.bos_token] = self.bos_token_id
+        self.extended_piece_to_id[self.eos_token] = self.eos_token_id
+
         # Create dictionaries on init
 
         if not lazy_init:
@@ -68,17 +107,39 @@ def single_token(self, token_id: int):
 
         return torch.tensor([[token_id]], dtype = torch.long)
 
+    # Encode string with special tokens
+
+    def encode_special(self, text: str):
+
+        if self.special_delimiters is None:
+            self.special_delimiters = re.compile("(" + "|".join(map(re.escape, self.extended_piece_to_id.keys())) + ")")
+
+        split = self.special_delimiters.split(text)
+        encoded = []
+
+        i = 0
+        while i < len(split):
+            if split[i] != "": encoded += self.tokenizer.EncodeAsIds(split[i])
+            if i + 1 < len(split): encoded += [self.extended_piece_to_id[split[i + 1]]]
+            i += 2
+
+        return encoded
+
+
     # Encode string
 
     # TODO: Handle added tokens for "special" models
 
-    def encode(self, text, add_bos = False, add_eos = False):
+    def encode(self, text, add_bos = False, add_eos = False, encode_special_tokens = False):
 
         if isinstance(text, list):
 
             # text is a list of strings
 
-            list_ids = self.tokenizer.EncodeAsIds(text)
+            if encode_special_tokens:
+                list_ids = [self.encode_special(t) for t in text]
+            else:
+                list_ids = self.tokenizer.EncodeAsIds(text)
 
             if add_bos:
                 for ids in list_ids: ids.insert(0, self.bos_token_id)
@@ -99,7 +160,10 @@ def encode(self, text, add_bos = False, add_eos = False):
 
             # text is a single string
 
-            ids = self.tokenizer.EncodeAsIds(text)
+            if encode_special_tokens:
+                ids = self.encode_special(text)
+            else:
+                ids = self.tokenizer.EncodeAsIds(text)
 
             if add_bos:
                 ids.insert(0, self.bos_token_id)
@@ -109,18 +173,45 @@ def encode(self, text, add_bos = False, add_eos = False):
 
             return torch.tensor(ids).to(torch.long).unsqueeze(0)
 
+    # Decode sequence with or without special tokens
+
+    def decode_(self, seq, decode_special_tokens):
+
+        if not decode_special_tokens:
+
+            max_token = self.tokenizer.vocab_size()
+            seq = [t for t in seq if (t != self.pad_token_id and t < max_token)]
+            if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
+            return self.tokenizer.Decode(seq)
+
+        else:
+
+            text = ""
+            start = 0
+            end = 0
+            while end < len(seq):
+                if seq[end] in self.extended_id_to_piece:
+                    if end > start: text += self.tokenizer.Decode(seq[start : end])
+                    text += self.extended_id_to_piece[seq[end]]
+                    end += 1
+                    start = end
+                else:
+                    end += 1
+            if end > start: text += self.tokenizer.Decode(seq[start : end])
+
+            return text
+
+
     # Decode IDs
 
-    def decode(self, ids):
+    def decode(self, ids, decode_special_tokens = False):
 
         if ids.dim() > 1:
 
             texts = []
             for i in range(ids.shape[0]):
 
                 seq = ids[i].tolist()
-                seq = [t for t in seq if t != self.pad_token_id]
-                if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
-                texts.append(self.tokenizer.Decode(seq))
+                texts.append(self.decode_(seq, decode_special_tokens))
 
             return texts
 
         else:
@@ -159,6 +250,12 @@ def get_id_to_piece_list(self):
             (p.replace("▁", " ") if not p.startswith("<") else self.tokenizer.decode(idx))
             for idx, p in enumerate(self.tokenizer.id_to_piece(all_tokens))
         ]
+
+        i = self.tokenizer.vocab_size()
+        while i in self.extended_id_to_piece:
+            self.id_to_piece.append(self.extended_id_to_piece[i])
+            i += 1
+
         return self.id_to_piece
 
 
@@ -209,6 +306,10 @@ def get_prefix_id_to_ids_dict(self):
 
         self.prefix_id_to_ids = { piece_to_id[piece]: ids for piece, ids in prefix_to_ids.items() }
 
+        for i in range(self.config.vocab_size):
+            if i not in self.prefix_id_to_ids:
+                self.prefix_id_to_ids[i] = [i]
+
         return self.prefix_id_to_ids
 
 
diff --git a/exllamav2/version.py b/exllamav2/version.py
index 221ce5dc..c1336bdb 100644
--- a/exllamav2/version.py
+++ b/exllamav2/version.py
@@ -1 +1 @@
-__version__ = "0.0.4"
\ No newline at end of file
+__version__ = "0.0.5"
\ No newline at end of file