Fix tokenizers (#887)
* Update pyproject.toml

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* fix_tokenizer

* Update tokenizer_utils.py

* Update tokenizer_utils.py
danielhanchen authored Aug 7, 2024
1 parent bfe38e6 commit 8001d30
Showing 4 changed files with 180 additions and 23 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -47,7 +47,7 @@ huggingface = [
"peft>=0.7.1,!=0.11.0",
"protobuf<4.0.0",
"huggingface_hub",
"hf-transfer",
"hf_transfer",
]
cu118only = [
"xformers==0.0.22.post7",
@@ -178,7 +178,7 @@ colab-new = [
"numpy",
"protobuf<4.0.0",
"huggingface_hub",
"hf-transfer",
"hf_transfer",
]
colab-no-deps = [
"accelerate>=0.26.1",
83 changes: 64 additions & 19 deletions unsloth/models/_utils.py
@@ -332,7 +332,6 @@ def prepare_model_for_kbit_training(
"""

# Freeze all parameters except LoRA
import re
with torch.no_grad():
for name, param in model.named_parameters():
if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
@@ -389,12 +388,14 @@ def patch_tokenizer(model, tokenizer):
Fixes https://github.com/unslothai/unsloth/issues/5
"""
possible_reserved_tokens = (
"<|finetune_right_pad_id|>", # Llama-3.1
"<pad>", # Mistral Nemo
"<|reserved", # Llama-3
"<|placeholder", # Phi-3
"[control", # Mistral type models
"<pad>", # Mistral Nemo
"<|finetune_right_pad_id|>", # Llama-3.1
)
joiner = "\1\0=+=\0\1"
number_repetitions = 3 - 1 # Number of reserved tokens needed

if model is not None:
model.config.update({"unsloth_version" : __version__})
@@ -412,28 +413,69 @@ def patch_tokenizer(model, tokenizer):
if bad_pad_token:
# Find a better pad token
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
possible_pad_token = None
n_possible_pad_tokens = 0
for added_token in added_tokens[::-1]:
if added_token.startswith(possible_reserved_tokens):
if possible_pad_token is None: possible_pad_token = added_token
n_possible_pad_tokens += 1
# We must see at least 3 of the reserved tokens
if n_possible_pad_tokens >= 3: break
all_added_tokens = joiner.join(added_tokens[::-1])
all_added_tokens += joiner

final_pad_token = None
final_good_match = False

for possible_reserved_token in possible_reserved_tokens:
possible_reserved_token = re.escape(possible_reserved_token)
found = re.finditer(f"{possible_reserved_token}", all_added_tokens)
first_match = None
good_match = False
for j, x in enumerate(found):
if j == 0: first_match = x
if j >= number_repetitions:
good_match = True
break
pass
pass

if first_match is None: continue

# If it ends with |> or > etc, then set it as a good pad token!
start = first_match.span(0)[0]
possible_pad_token = first_match.group(0)
end = all_added_tokens.find(joiner, start)
first_match = all_added_tokens[start:end]

if first_match is not None:
good_match = possible_pad_token.endswith((">", "|>", "]", ")"))
pass
possible_pad_token = first_match

# Replace current pad token if another exact match is found
if not final_good_match and good_match:
final_good_match = True
final_pad_token = possible_pad_token
break
else:
final_good_match = False
final_pad_token = possible_pad_token
pass
pass
if n_possible_pad_tokens < 3: possible_pad_token = None
possible_pad_token = final_pad_token

if possible_pad_token is None:
# Try unk_token
# Try unk_token
if possible_pad_token is None and hasattr(tokenizer, "unk_token"):
possible_pad_token = tokenizer.unk_token
pass

# Check pad token's id must be less than vocab size
if possible_pad_token is not None:
check_pad_token = tokenizer(possible_pad_token, add_special_tokens = False).input_ids
if len(check_pad_token) != 1:
possible_pad_token = None
if check_pad_token[0] >= config.vocab_size:
possible_pad_token = None
pass

if possible_pad_token is None:
# Failure to find a good replacement!! We shall manually add one!
new_pad_token = "<|PAD_TOKEN|>"
while new_pad_token in tokenizer.get_vocab():
new_pad_token += "#"
new_pad_token = f"<{new_pad_token}>"
pass
possible_pad_token = new_pad_token
pass
Expand All @@ -447,11 +489,16 @@ def patch_tokenizer(model, tokenizer):
tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
tokenizer.pad_token = possible_pad_token
if model is not None:
config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
model.config.update({"pad_token_id" : tokenizer.pad_token_id})
model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
else:
if model is not None:
if model.config.pad_token_id is None:
config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
model.config.update({"pad_token_id" : tokenizer.pad_token_id})
model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
pass
pass
model.generation_config.update(max_length = model.config.max_position_embeddings)
return model, tokenizer
pass
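
The hunk above replaces the old startswith scan over the added tokens with a joined-string regex search. A minimal, self-contained sketch of that idea (hypothetical helper name, simplified from the diff; not the exact Unsloth implementation):

import re

def pick_reserved_pad_token(added_tokens, reserved_prefixes, min_repeats = 3):
    # Join the added tokens (last first) with a sentinel so a single regex scan
    # can both count repetitions and recover the full token text around a match.
    joiner   = "\1\0=+=\0\1"
    haystack = joiner.join(reversed(added_tokens)) + joiner

    fallback = None
    for prefix in reserved_prefixes:
        matches = list(re.finditer(re.escape(prefix), haystack))
        if not matches: continue

        # Recover the complete added token containing the first match.
        start     = matches[0].start()
        end       = haystack.find(joiner, start)
        candidate = haystack[start:end]

        # Prefer a candidate that repeats enough times and looks like a closed
        # special token (ends in ">", "|>", "]" or ")").
        if len(matches) >= min_repeats and candidate.endswith((">", "|>", "]", ")")):
            return candidate
        if fallback is None: fallback = candidate
    return fallback

# Llama-3 style reserved tokens resolve to the last reserved token in the list:
tokens = [f"<|reserved_special_token_{i}|>" for i in range(250)]
print(pick_reserved_pad_token(tokens, ("<|finetune_right_pad_id|>", "<pad>", "<|reserved")))
# -> <|reserved_special_token_249|>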

@@ -462,7 +509,6 @@ def patch_tokenizer(model, tokenizer):
from peft import __version__ as peft_version
if Version(peft_version) < Version("0.12.0"):
from peft.tuners.lora.layer import LoraLayer
import inspect, re
try:
source = inspect.getsource(LoraLayer.update_layer)
text = "if weight is not None:\n"
@@ -688,7 +734,6 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None,
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
import re
BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
1 change: 1 addition & 0 deletions unsloth/models/llama.py
@@ -1397,6 +1397,7 @@ def from_pretrained(
padding_side = "right",
token = token,
trust_remote_code = trust_remote_code,
fix_tokenizer = fix_tokenizer,
)

model, tokenizer = patch_tokenizer(model, tokenizer)
115 changes: 113 additions & 2 deletions unsloth/tokenizer_utils.py
@@ -454,13 +454,14 @@ class SentencePieceTokenTypes(IntEnum):
pass


def load_correct_tokenizer(
def _load_correct_tokenizer(
tokenizer_name,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
cache_dir = "huggingface_tokenizers_cache",
fix_tokenizer = True,
):
if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT:
cache_dir = cache_dir
@@ -501,7 +502,10 @@ def load_correct_tokenizer(
cache_dir = cache_dir,
)

if tokenizer_name in IGNORED_TOKENIZER_NAMES:
if not fix_tokenizer or tokenizer_name in IGNORED_TOKENIZER_NAMES:
return fast_tokenizer
# Ignore Mistral ones - they're a bit weird to handle!
elif "mistral" in tokenizer_name.lower():
return fast_tokenizer
elif slow_tokenizer is not None:
if hasattr(fast_tokenizer, "add_bos_token") and hasattr(slow_tokenizer, "add_bos_token"):
@@ -522,6 +526,113 @@ def load_correct_tokenizer(
pass


def load_correct_tokenizer(
tokenizer_name,
model_max_length = None,
padding_side = "right",
token = None,
trust_remote_code = False,
cache_dir = "huggingface_tokenizers_cache",
fix_tokenizer = True,
):
tokenizer = _load_correct_tokenizer(
tokenizer_name = tokenizer_name,
model_max_length = model_max_length,
padding_side = padding_side,
token = token,
trust_remote_code = trust_remote_code,
cache_dir = cache_dir,
fix_tokenizer = fix_tokenizer,
)

### 1. Fixup tokenizer's chat_template
old_chat_template = getattr(tokenizer, "chat_template", None)

# Ignore mistral type models since they don't have a add_generation_prompt
if "mistral" in str(getattr(tokenizer, "name_or_path", "")).lower():
chat_template = old_chat_template

# Also check Llama-2 old style models
elif old_chat_template is not None and \
"[/INST]" in old_chat_template and "[INST]" in old_chat_template and \
"bos_token" in old_chat_template and "eos_token" in old_chat_template:

chat_template = old_chat_template

else:
chat_template = fix_chat_template(tokenizer)
if old_chat_template is not None and chat_template is None:
raise RuntimeError(
"Unsloth: Fixing chat template failed - please file a report immediately!"
)
pass
pass

tokenizer.chat_template = chat_template
return tokenizer
pass
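
A hedged usage sketch of the new wrapper (the checkpoint name below is only an illustration): fix_tokenizer = False is threaded down to _load_correct_tokenizer, which then returns the fast tokenizer without the slow/fast reconciliation, while the chat_template pass in this wrapper still runs.

from unsloth.tokenizer_utils import load_correct_tokenizer

tokenizer = load_correct_tokenizer(
    "unsloth/llama-3-8b-bnb-4bit",   # illustrative checkpoint name
    model_max_length = 4096,
    fix_tokenizer    = False,        # new flag in this commit
)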


def _fix_chat_template(chat_template):
endfor = "{% endfor %}"
where = chat_template.find(endfor)
if where == -1: return chat_template

after_endfor = chat_template[where + len(endfor):]

if "{% if" not in after_endfor and "{% set " not in after_endfor and \
after_endfor.startswith("{{") and after_endfor.endswith("}}") and \
after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1:

after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}"

chat_template = chat_template[:where + len(endfor)] + after_endfor
pass
return chat_template
pass
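
_fix_chat_template only wraps a template whose content after its first {% endfor %} is a single {{ ... }} expression. A hedged illustration with a made-up Jinja template (not taken from any real model):

# Hypothetical template: the trailing {{ '<assistant>' }} is emitted
# unconditionally, so add_generation_prompt has no effect yet.
template = (
    "{% for message in messages %}{{ message['content'] }}{% endfor %}"
    "{{ '<assistant>' }}"
)
fixed = _fix_chat_template(template)
# fixed now ends with:
# {% endfor %}{% if add_generation_prompt %}{{ '<assistant>' }}{% endif %}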


def fix_chat_template(tokenizer):
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is None: return None

### 1. Check if add_generation_prompt works
messages = [
{"role": "user", "content": "Who are you?"},
]
no = tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
yes = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)

if no == yes:
# SAME?! That's not good! We check for add_generation_prompt
if "{% if add_generation_prompt %}" not in chat_template:
# Try fixing it by adding it
new_chat_template = _fix_chat_template(chat_template)
if "{% if add_generation_prompt %}" not in new_chat_template:
raise RuntimeError(
f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
"does not have a {% if add_generation_prompt %} for generation purposes.\n"\
"Please file a bug report immediately - thanks!"
)
else:
logger.warning_once(
"Unsloth: We successfully patched the tokenizer to add a {% if add_generation_prompt %} to the chat_template.\n"\
"This is not a bug, but please notify the Unsloth maintainers - thanks!"
)
chat_template = new_chat_template
pass
else:
raise RuntimeError(
f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\
"has a {% if add_generation_prompt %} for generation purposes, but wasn't provided correctly.\n"\
"Please file a bug report immediately - thanks!"
)
pass
pass
return chat_template
pass


def check_tokenizer(
model,
tokenizer,
