diff --git a/README.md b/README.md
index c90159c3..dbfeb8ba 100644
--- a/README.md
+++ b/README.md
@@ -171,4 +171,7 @@ the converter to better facilitate scripted jobs.
**2023-09-27**: Prebuilt wheels are now available, credit to [@jllllll](https://github.com/jllllll). They're on the
[releases page here](https://github.com/turboderp/exllamav2/releases). A solution to installing prebuilt wheels straight
-from PyPI is still pending. Updated installation instructions above.
\ No newline at end of file
+from PyPI is still pending. Updated installation instructions above.
+
+**2023-10-03**: Added support for extended vocabularies and alternative BOS/EOS/UNK tokens, as well as the ability to
+encode/decode sequences with special tokens. Added an Orca (ChatML) template to the chatbot example.
\ No newline at end of file
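
As a quick illustration of the special-token handling announced above, here is a minimal sketch (not taken from this patch): it assumes a model directory that ships an added_tokens.json, and the config setup shown is the usual ExLlamaV2 pattern rather than something introduced here.

# Minimal sketch, assuming a model directory with an added_tokens.json file.
from exllamav2.config import ExLlamaV2Config
from exllamav2.tokenizer import ExLlamaV2Tokenizer

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"     # hypothetical path
config.prepare()                        # standard config preparation, not part of this patch
tokenizer = ExLlamaV2Tokenizer(config)

text = "<|im_start|>user\nHello!<|im_end|>"

# encode_special_tokens = True maps pieces from added_tokens.json to their own IDs
ids = tokenizer.encode(text, add_bos = True, encode_special_tokens = True)

# decode_special_tokens = True renders those IDs back as their piece strings
print(tokenizer.decode(ids, decode_special_tokens = True)[0])
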
diff --git a/examples/chat.py b/examples/chat.py
index 72292d51..622b4873 100644
--- a/examples/chat.py
+++ b/examples/chat.py
@@ -19,11 +19,14 @@
)
from chat_formatting import CodeBlockFormatter
+from chat_prompts import prompt_formats
+prompt_formats_list = list(prompt_formats.keys())
# Options
parser = argparse.ArgumentParser(description = "Simple Llama2 chat example for ExLlamaV2")
-parser.add_argument("-mode", "--mode", choices = ["llama", "raw", "codellama"], help = "Chat mode. Use llama for Llama 1/2 chat finetunes.")
+parser.add_argument("-modes", "--modes", action = "store_true", help = "List available modes and exit.")
+parser.add_argument("-mode", "--mode", choices = prompt_formats_list, help = "Chat mode. Use llama for Llama 1/2 chat finetunes.")
parser.add_argument("-un", "--username", type = str, default = "User", help = "Username when using raw chat mode")
parser.add_argument("-bn", "--botname", type = str, default = "Chatbort", help = "Bot name when using raw chat mode")
parser.add_argument("-sp", "--system_prompt", type = str, help = "Use custom system prompt")
@@ -37,86 +40,61 @@
parser.add_argument("-resc", "--response_chunk", type = int, default = 250, help = "Space to reserve in context for reply, default = 250")
parser.add_argument("-ncf", "--no_code_formatting", action = "store_true", help = "Disable code formatting/syntax highlighting")
-# Initialize model and tokenizer
+# Args
model_init.add_args(parser)
args = parser.parse_args()
-model_init.check_args(args)
-model_init.print_options(args)
-model, tokenizer = model_init.init(args)
-# Create cache
-
-cache = ExLlamaV2Cache(model)
+# Prompt templates/modes
-# Prompt templates
+if args.modes:
+ print(" -- Available formats:")
+ for k, v in prompt_formats.items():
+ print(f" -- {k:12} : {v().description}")
+ sys.exit()
username = args.username
botname = args.botname
system_prompt = args.system_prompt
-mode = args.mode
-
-if mode == "llama" or mode == "codellama":
-
- if not system_prompt:
-
- if mode == "llama":
-
- system_prompt = \
- """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """ + \
- """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """ + \
- """Please ensure that your responses are socially unbiased and positive in nature."""
-
- elif mode == "codellama":
-
- system_prompt = \
- """You are a helpful coding assistant. Always answer as helpfully as possible."""
-
- first_prompt = \
-        """[INST] <<SYS>>\n<|system_prompt|>\n<</SYS>>\n\n<|user_prompt|> [/INST]"""
- subs_prompt = \
- """[INST] <|user_prompt|> [/INST]"""
-
-elif mode == "raw":
+if args.mode is None:
+ print(" ## Error: No mode specified.")
+ sys.exit()
- if not system_prompt:
+prompt_format = prompt_formats[args.mode]()
+prompt_format.botname = botname
+prompt_format.username = username
+if system_prompt is None: system_prompt = prompt_format.default_system_prompt()
- system_prompt = \
- f"""This is a conversation between a helpful AI assistant named {botname} and a """ + ("""user named {username}.""" if username != "User" else """user.""")
+# Initialize model and tokenizer
- first_prompt = \
- f"""<|system_prompt|>\n{username}: <|user_prompt|>\n{botname}:"""
+model_init.check_args(args)
+model_init.print_options(args)
+model, tokenizer = model_init.init(args)
- subs_prompt = \
- f"""{username}: <|user_prompt|>\n{botname}:"""
+# Create cache
-else:
+cache = ExLlamaV2Cache(model)
- print(" ## Error: Incorrect/no mode specified.")
- sys.exit()
# Chat context
def format_prompt(user_prompt, first):
- global system_prompt, first_prompt, subs_prompt
+ global system_prompt, prompt_format
if first:
- return first_prompt \
+ return prompt_format.first_prompt() \
.replace("<|system_prompt|>", system_prompt) \
.replace("<|user_prompt|>", user_prompt)
else:
- return subs_prompt \
+ return prompt_format.subs_prompt() \
.replace("<|user_prompt|>", user_prompt)
def encode_prompt(text):
- global tokenizer, mode
+ global tokenizer, prompt_format
- if mode == "llama" or mode == "codellama":
- return tokenizer.encode(text, add_bos = True)
-
- if mode == "raw":
- return tokenizer.encode(text)
+ add_bos, add_eos, encode_special_tokens = prompt_format.encoding_options()
+ return tokenizer.encode(text, add_bos = add_bos, add_eos = add_eos, encode_special_tokens = encode_special_tokens)
user_prompts = []
responses_ids = []
@@ -130,7 +108,8 @@ def get_tokenized_context(max_len):
for turn in range(len(user_prompts)):
- up_ids = encode_prompt(format_prompt(user_prompts[turn], context.shape[-1] == 0))
+ up_text = format_prompt(user_prompts[turn], context.shape[-1] == 0)
+ up_ids = encode_prompt(up_text)
context = torch.cat([context, up_ids], dim=-1)
if turn < len(responses_ids):
@@ -161,13 +140,7 @@ def get_tokenized_context(max_len):
# Stop conditions
-if mode == "llama" or mode == "codellama":
-
- generator.set_stop_conditions([tokenizer.eos_token_id])
-
-if mode == "raw":
-
- generator.set_stop_conditions([username + ":", username[0:1] + ":", username.upper() + ":", username.lower() + ":", tokenizer.eos_token_id])
+generator.set_stop_conditions(prompt_format.stop_conditions(tokenizer))
# ANSI color codes
@@ -175,6 +148,7 @@ def get_tokenized_context(max_len):
col_user = "\u001b[33;1m" # Yellow
col_bot = "\u001b[34;1m" # Blue
col_error = "\u001b[31;1m"  # Bright red
+col_sysprompt = "\u001b[37;1m"  # Bright white
# Code block formatting
@@ -188,6 +162,11 @@ def get_tokenized_context(max_len):
# Main loop
+print(f" -- Prompt format: {args.mode}")
+print(f" -- System prompt:")
+print()
+print(col_sysprompt + system_prompt.strip() + col_default)
+
while True:
# Get user prompt
@@ -207,7 +186,7 @@ def get_tokenized_context(max_len):
# Stream response
- if mode == "raw":
+ if prompt_format.print_bot_name():
print(col_bot + botname + ": " + col_default, end = "")
@@ -288,7 +267,7 @@ def get_tokenized_context(max_len):
if eos:
- if mode == "llama" or mode == "codellama":
+ if prompt_format.print_extra_newline():
print()
break
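
For reference, the refactored prompt-format flow can be exercised on its own. A small standalone sketch (not part of the example script) using the chatml format defined in chat_prompts.py:

# Standalone illustration of first-turn vs. subsequent-turn templating, assuming it is
# run from the examples directory so chat_prompts is importable.
from chat_prompts import prompt_formats

fmt = prompt_formats["chatml"]()
fmt.botname = "Chatbort"
system_prompt = fmt.default_system_prompt()

first = fmt.first_prompt() \
    .replace("<|system_prompt|>", system_prompt) \
    .replace("<|user_prompt|>", "Hello!")
subs = fmt.subs_prompt() \
    .replace("<|user_prompt|>", "And another question...")

print(first)   # <|im_start|>system ... <|im_start|>assistant block
print(subs)    # continues the log without repeating the system prompt

# In chat.py these strings are then passed to tokenizer.encode() with the
# add_bos/add_eos/encode_special_tokens flags returned by fmt.encoding_options().
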
diff --git a/examples/chat_prompts.py b/examples/chat_prompts.py
new file mode 100644
index 00000000..606e6889
--- /dev/null
+++ b/examples/chat_prompts.py
@@ -0,0 +1,177 @@
+
+class PromptFormat:
+
+ botname = "Chatbort"
+ username = "User"
+
+ def __init__(self):
+ pass
+
+ #
+
+ def default_system_prompt(self):
+ raise NotImplementedError
+
+ def first_prompt(self):
+ raise NotImplementedError
+
+ def subs_prompt(self):
+ raise NotImplementedError
+
+ def stop_conditions(self, tokenizer):
+ raise NotImplementedError
+
+ def encoding_options(self): # (add_bos, add_eos, encode_special_tokens)
+ raise NotImplementedError
+
+ def print_bot_name(self):
+ return False
+
+ def print_extra_newline(self):
+ return False
+
+
+class PromptFormat_raw(PromptFormat):
+
+ description = "Model-agnostic mode simulating a raw chatlog"
+
+ def __init__(self):
+ super().__init__()
+ pass
+
+ def default_system_prompt(self):
+ return \
+ f"""This is a conversation between a helpful AI assistant named {self.botname} and a """ + \
+ (f"""user named {self.username}.""" if self.username != "User" else """user.""")
+
+ def first_prompt(self):
+ return \
+ f"""<|system_prompt|>\n{self.username}: <|user_prompt|>\n{self.botname}:"""
+
+ def subs_prompt(self):
+ return \
+ f"""{self.username}: <|user_prompt|>\n{self.botname}:"""
+
+ def stop_conditions(self, tokenizer):
+ return \
+ [self.username + ":",
+ self.username[0:1] + ":",
+ self.username.upper() + ":",
+ self.username.lower() + ":",
+ tokenizer.eos_token_id]
+
+ def encoding_options(self):
+ return False, False, False
+
+ def print_bot_name(self):
+ return True
+
+
+class PromptFormat_llama(PromptFormat):
+
+ description = "Llama-chat, Llama2-chat and Mistral-instruct models"
+
+ def __init__(self):
+ super().__init__()
+ pass
+
+ def default_system_prompt(self):
+ return \
+ """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. """ + \
+ """Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. """ + \
+ """Please ensure that your responses are socially unbiased and positive in nature."""
+
+ def first_prompt(self):
+ return \
+            """[INST] <<SYS>>\n<|system_prompt|>\n<</SYS>>\n\n<|user_prompt|> [/INST]"""
+
+ def subs_prompt(self):
+ return \
+ """[INST] <|user_prompt|> [/INST]"""
+
+ def stop_conditions(self, tokenizer):
+ return \
+ [tokenizer.eos_token_id]
+
+ def encoding_options(self):
+ return True, False, False
+
+ def print_extra_newline(self):
+ return True
+
+
+class PromptFormat_codellama(PromptFormat_llama):
+
+ description = "CodeLlama-instruct"
+
+ def __init__(self):
+ super().__init__()
+ pass
+
+ def default_system_prompt(self):
+ return \
+ """You are a helpful coding assistant. Always answer as helpfully as possible."""
+
+
+class PromptFormat_chatml(PromptFormat):
+
+ description = "ChatML format, as used by e.g. (Mistral)Orca"
+
+ def __init__(self):
+ super().__init__()
+ pass
+
+ def default_system_prompt(self):
+ return \
+ f"""You are {self.botname}, a large language model. Answer as concisely as possible."""
+
+ def first_prompt(self):
+ return \
+ """<|im_start|>system\n""" + \
+ """<|system_prompt|>\n""" + \
+ """<|im_end|>\n""" + \
+ """<|im_start|>user\n""" + \
+ """<|user_prompt|><|im_end|>\n""" + \
+ """<|im_start|>assistant\n"""
+
+ def subs_prompt(self):
+ return \
+ """<|im_end|>\n""" + \
+ """<|im_start|>user\n""" + \
+ """<|user_prompt|><|im_end|>\n""" + \
+ """<|im_start|>assistant\n"""
+
+ def stop_conditions(self, tokenizer):
+ return \
+ [tokenizer.eos_token_id,
+ """<|im_end|>"""]
+
+ def encoding_options(self):
+ return False, False, True
+
+ def print_extra_newline(self):
+ return True
+
+
+class PromptFormat_tinyllama(PromptFormat_chatml):
+
+ description = "ChatML format, but ignoring special/added tokens. Use for TinyLlama-chat v0.3"
+
+ def encoding_options(self):
+ return False, False, False
+
+
+prompt_formats = \
+{
+ "raw": PromptFormat_raw,
+ "llama": PromptFormat_llama,
+ "codellama": PromptFormat_codellama,
+ "chatml": PromptFormat_chatml,
+ "tinyllama": PromptFormat_tinyllama,
+}
+
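
The dictionary above is the single registration point, so adding a template only requires a new subclass plus one entry. A hypothetical Alpaca-style format (purely illustrative, not part of this patch) might look like:

# Hypothetical example following the PromptFormat interface defined above.
class PromptFormat_alpaca(PromptFormat):

    description = "Alpaca-style instruct format (illustrative only)"

    def default_system_prompt(self):
        return "Below is an instruction that describes a task. Write a response that appropriately completes the request."

    def first_prompt(self):
        return "<|system_prompt|>\n\n### Instruction:\n<|user_prompt|>\n\n### Response:\n"

    def subs_prompt(self):
        return "### Instruction:\n<|user_prompt|>\n\n### Response:\n"

    def stop_conditions(self, tokenizer):
        return [tokenizer.eos_token_id, "### Instruction:"]

    def encoding_options(self):
        return True, False, False   # add_bos, add_eos, encode_special_tokens

    def print_extra_newline(self):
        return True

# Register it so "-modes" lists it and "-mode alpaca" selects it:
# prompt_formats["alpaca"] = PromptFormat_alpaca
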
diff --git a/exllamav2/generator/base.py b/exllamav2/generator/base.py
index 34940976..f4ca08fb 100644
--- a/exllamav2/generator/base.py
+++ b/exllamav2/generator/base.py
@@ -47,7 +47,9 @@ def generate_simple(self, prompt: str or list,
gen_settings: ExLlamaV2Sampler.Settings,
num_tokens: int,
seed = None,
- token_healing = False):
+ token_healing = False,
+ encode_special_tokens = False,
+ decode_special_tokens = False ):
# Apply seed
@@ -56,7 +58,7 @@ def generate_simple(self, prompt: str or list,
# Tokenize input and produce padding mask if needed
batch_size = 1 if isinstance(prompt, str) else len(prompt)
- ids = self.tokenizer.encode(prompt)
+ ids = self.tokenizer.encode(prompt, encode_special_tokens = encode_special_tokens)
overflow = ids.shape[-1] + num_tokens - self.model.config.max_seq_len
if overflow > 0: ids = ids[:, overflow:]
@@ -93,7 +95,7 @@ def generate_simple(self, prompt: str or list,
# Decode
- text = self.tokenizer.decode(self.sequence_ids)
+ text = self.tokenizer.decode(self.sequence_ids, decode_special_tokens = decode_special_tokens)
if isinstance(prompt, str): return text[0]
return text
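
A short usage sketch for the new keyword arguments (assuming a generator, sampler settings and a ChatML-style model are already set up as in the repository's examples):

# Sketch only: assumes model/tokenizer/cache/generator/settings are initialized
# as in the examples; the new keyword arguments are the point of interest.
prompt = "<|im_start|>user\nWrite a haiku about llamas.<|im_end|>\n<|im_start|>assistant\n"

output = generator.generate_simple(
    prompt,
    settings,                       # an ExLlamaV2Sampler.Settings instance
    num_tokens = 200,
    encode_special_tokens = True,   # map <|im_start|>/<|im_end|> to their added-token IDs
    decode_special_tokens = False,  # strip special tokens from the returned text
)
print(output)
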
diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index 37b13bd9..c9020162 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -38,7 +38,8 @@ def clone(self):
def disallow_tokens(self, tokenizer, tokens):
if self.token_bias is None:
- self.token_bias = torch.zeros((tokenizer.config.vocab_size,), dtype = torch.float)
+ padding = -tokenizer.config.vocab_size % 32
+ self.token_bias = torch.zeros((tokenizer.config.vocab_size + padding,), dtype = torch.float)
self.token_bias[tokens] = float("-inf")
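
The padding expression rounds the bias tensor up to the next multiple of 32, presumably so it lines up with logits padded for alignment. A quick worked example of the arithmetic:

# -n % 32 is the distance from n up to the next multiple of 32 (0 if already aligned)
vocab_size = 32003                 # hypothetical size, e.g. a Llama vocab plus added tokens
padding = -vocab_size % 32         # 29
print(vocab_size + padding)        # 32032, a multiple of 32
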
diff --git a/exllamav2/tokenizer.py b/exllamav2/tokenizer.py
index 84e3588c..219bd69f 100644
--- a/exllamav2/tokenizer.py
+++ b/exllamav2/tokenizer.py
@@ -1,6 +1,7 @@
from exllamav2.config import ExLlamaV2Config
from sentencepiece import SentencePieceProcessor
import torch
+import os, json, re
class ExLlamaV2Tokenizer:
@@ -35,6 +36,10 @@ def __init__(self, children = None, leaf = None):
char_trie: Trie = None
char_trie_ci: Trie = None
+ extended_id_to_piece = {}
+ extended_piece_to_id = {}
+ special_delimiters = None
+
tokenized_str_cache = {}
max_cached_strings = 100
@@ -44,11 +49,45 @@ def __init__(self, config, lazy_init = False):
self.tokenizer = SentencePieceProcessor(model_file = self.config.tokenizer_path)
+ # Load added_tokens.json if present
+
+ added_tokens_path = os.path.join(self.config.model_dir, "added_tokens.json")
+ if os.path.exists(added_tokens_path):
+ with open(added_tokens_path) as f:
+ self.extended_piece_to_id = json.load(f)
+
+ self.extended_id_to_piece = { v: k for k, v in self.extended_piece_to_id.items() }
+
+ # Get control token IDs
+
+ # self.eos_token_id = self.tokenizer.eos_id()
+ # self.bos_token_id = self.tokenizer.bos_id()
+ # self.unk_token_id = config.unk_token_id
+
self.unk_token_id = self.tokenizer.unk_id()
- self.eos_token_id = self.tokenizer.eos_id()
- self.bos_token_id = self.tokenizer.bos_id()
+ self.eos_token_id = config.eos_token_id
+ self.bos_token_id = config.bos_token_id
+
+ # Get control token strings
+
+        try: self.unk_token = self.extended_id_to_piece.get(self.unk_token_id) or self.tokenizer.id_to_piece(self.unk_token_id)
+        except: pass
+        try: self.bos_token = self.extended_id_to_piece.get(self.bos_token_id) or self.tokenizer.id_to_piece(self.bos_token_id)
+        except: pass
+        try: self.eos_token = self.extended_id_to_piece.get(self.eos_token_id) or self.tokenizer.id_to_piece(self.eos_token_id)
+        except: pass
+
self.pad_token_id = 0
+ # Make sure extended vocab contains control tokens
+
+ self.extended_id_to_piece[self.unk_token_id] = self.unk_token
+ self.extended_id_to_piece[self.bos_token_id] = self.bos_token
+ self.extended_id_to_piece[self.eos_token_id] = self.eos_token
+ self.extended_piece_to_id[self.unk_token] = self.unk_token_id
+ self.extended_piece_to_id[self.bos_token] = self.bos_token_id
+ self.extended_piece_to_id[self.eos_token] = self.eos_token_id
+
# Create dictionaries on init
if not lazy_init:
@@ -68,17 +107,39 @@ def single_token(self, token_id: int):
return torch.tensor([[token_id]], dtype = torch.long)
+ # Encode string with special tokens
+
+ def encode_special(self, text: str):
+
+ if self.special_delimiters is None:
+ self.special_delimiters = re.compile("(" + "|".join(map(re.escape, self.extended_piece_to_id.keys())) + ")")
+
+ split = self.special_delimiters.split(text)
+ encoded = []
+
+ i = 0
+ while i < len(split):
+ if split[i] != "": encoded += self.tokenizer.EncodeAsIds(split[i])
+ if i + 1 < len(split): encoded += [self.extended_piece_to_id[split[i + 1]]]
+ i += 2
+
+ return encoded
+
+
# Encode string
# TODO: Handle added tokens for "special" models
- def encode(self, text, add_bos = False, add_eos = False):
+ def encode(self, text, add_bos = False, add_eos = False, encode_special_tokens = False):
if isinstance(text, list):
# text is a list of strings
- list_ids = self.tokenizer.EncodeAsIds(text)
+ if encode_special_tokens:
+ list_ids = [self.encode_special(t) for t in text]
+ else:
+ list_ids = self.tokenizer.EncodeAsIds(text)
if add_bos:
for ids in list_ids: ids.insert(0, self.bos_token_id)
@@ -99,7 +160,10 @@ def encode(self, text, add_bos = False, add_eos = False):
# text is a single string
- ids = self.tokenizer.EncodeAsIds(text)
+ if encode_special_tokens:
+ ids = self.encode_special(text)
+ else:
+ ids = self.tokenizer.EncodeAsIds(text)
if add_bos:
ids.insert(0, self.bos_token_id)
@@ -109,18 +173,45 @@ def encode(self, text, add_bos = False, add_eos = False):
return torch.tensor(ids).to(torch.long).unsqueeze(0)
+ # Decode sequence with or without special tokens
+
+ def decode_(self, seq, decode_special_tokens):
+
+ if not decode_special_tokens:
+
+ max_token = self.tokenizer.vocab_size()
+ seq = [t for t in seq if (t != self.pad_token_id and t < max_token)]
+ if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
+ return self.tokenizer.Decode(seq)
+
+ else:
+
+ text = ""
+ start = 0
+ end = 0
+ while end < len(seq):
+ if seq[end] in self.extended_id_to_piece:
+ if end > start: text += self.tokenizer.Decode(seq[start : end])
+ text += self.extended_id_to_piece[seq[end]]
+ end += 1
+ start = end
+ else:
+ end += 1
+ if end > start: text += self.tokenizer.Decode(seq[start : end])
+
+ return text
+
+
# Decode IDs
- def decode(self, ids):
+ def decode(self, ids, decode_special_tokens = False):
if ids.dim() > 1:
texts = []
for i in range(ids.shape[0]):
seq = ids[i].tolist()
- seq = [t for t in seq if t != self.pad_token_id]
- if self.eos_token_id in seq: seq = seq[:seq.index(self.eos_token_id)]
- texts.append(self.tokenizer.Decode(seq))
+ texts.append(self.decode_(seq, decode_special_tokens))
return texts
else:
@@ -159,6 +250,12 @@ def get_id_to_piece_list(self):
(p.replace("▁", " ") if not p.startswith("<") else self.tokenizer.decode(idx))
for idx, p in enumerate(self.tokenizer.id_to_piece(all_tokens))
]
+
+ i = self.tokenizer.vocab_size()
+ while i in self.extended_id_to_piece:
+ self.id_to_piece.append(self.extended_id_to_piece[i])
+ i += 1
+
return self.id_to_piece
@@ -209,6 +306,10 @@ def get_prefix_id_to_ids_dict(self):
self.prefix_id_to_ids = { piece_to_id[piece]: ids for piece, ids in prefix_to_ids.items() }
+ for i in range(self.config.vocab_size):
+ if i not in self.prefix_id_to_ids:
+ self.prefix_id_to_ids[i] = [i]
+
return self.prefix_id_to_ids
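
To make encode_special concrete, here is a standalone sketch of the splitting step with made-up token IDs (real values come from the model's added_tokens.json):

import re

# Stand-in for extended_piece_to_id as loaded from added_tokens.json (IDs are made up)
extended_piece_to_id = { "<|im_start|>": 32001, "<|im_end|>": 32002 }

# Same construction as in encode_special: the capturing group makes re.split() keep
# the special pieces as the odd-indexed elements of the result
special_delimiters = re.compile("(" + "|".join(map(re.escape, extended_piece_to_id.keys())) + ")")

text = "<|im_start|>user\nHi<|im_end|>"
print(special_delimiters.split(text))
# ['', '<|im_start|>', 'user\nHi', '<|im_end|>', '']
#
# Even-indexed chunks go through SentencePiece (EncodeAsIds); odd-indexed chunks are
# looked up directly, giving [32001, ...ids for "user\nHi"..., 32002]. decode_() with
# decode_special_tokens = True reverses this by emitting extended IDs as their piece
# strings and decoding the ordinary runs in between; with the flag off, IDs at or above
# tokenizer.vocab_size() are dropped and decoding stops at the first EOS token.
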
diff --git a/exllamav2/version.py b/exllamav2/version.py
index 221ce5dc..c1336bdb 100644
--- a/exllamav2/version.py
+++ b/exllamav2/version.py
@@ -1 +1 @@
-__version__ = "0.0.4"
\ No newline at end of file
+__version__ = "0.0.5"
\ No newline at end of file