FastTokenizer not using the user_defined_symbols defined in the SentencePiece Model #28324

Closed
kitkhai opened this issue Jan 3, 2024 · 2 comments
kitkhai commented Jan 3, 2024

System Info

  • transformers version: 4.35.2
  • Platform: Linux-6.1.58+-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.1
  • Safetensors version: 0.4.1
  • Accelerate version: 0.25.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (False)
  • Tensorflow version (GPU?): 2.15.0 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
  • Jax version: 0.4.23
  • JaxLib version: 0.4.23
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?: no

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers.convert_slow_tokenizer import import_protobuf
from transformers import AutoTokenizer
from transformers import NllbTokenizer, NllbTokenizerFast

checkpoint = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.save_pretrained("old_tokenizer")

model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
m.ParseFromString(open("./old_tokenizer/sentencepiece.bpe.model", 'rb').read())

piece = m.SentencePiece()
piece.piece = "superlongword"
piece.score = -10
piece.type = 4  # USER_DEFINED

m.pieces.extend([piece])
with open("temp_sentencepiece.bpe.model", 'wb') as f:
    f.write(m.SerializeToString())


tokenizer_edited = NllbTokenizer(vocab_file="temp_sentencepiece.bpe.model", src_lang = "zho_Hans", tgt_lang = "eng_Latn")
tokenizer_edited_fast = NllbTokenizerFast(vocab_file="temp_sentencepiece.bpe.model", src_lang = "zho_Hans", tgt_lang = "eng_Latn")

sent = 'Hi there superlongword'
print(sent)
> Hi there superlongword

print("original tokenizer: ", tokenizer.tokenize(sent))
> original tokenizer:  ['▁Hi', '▁there', '▁super', 'long', 'word']

print("tokenizer with tokens: ", tokenizer_edited.tokenize(sent))
> tokenizer with tokens:  ['▁Hi', '▁there', '▁', 'superlongword']

print("tokenizer with tokens (Fast): ", tokenizer_edited_fast.tokenize(sent))
> tokenizer with tokens (Fast):  ['▁Hi', '▁there', '▁super', 'long', 'word']

Expected behavior

> Hi there superlongword
> original tokenizer:  ['▁Hi', '▁there', '▁super', 'long', 'word']
> tokenizer with tokens:  ['▁Hi', '▁there', '▁', 'superlongword']
> tokenizer with tokens (Fast):  ['▁Hi', '▁there', '▁', 'superlongword']

I faced an issue similar to one raised in a question on the HF forum, where the OP trained the tokenizer with user_defined_symbols; in my case I added the symbol to the SentencePiece model file directly, without retraining.

Note that I could just use the add_tokens method to achieve the same outcome, but because of another issue I raised (#28218) I would like to avoid add_tokens if possible.
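For reference, a minimal sketch of that add_tokens route (my own illustration, assuming the same NLLB checkpoint as above); added tokens are matched before the underlying model runs, so the new piece comes out whole:

from transformers import AutoTokenizer

checkpoint = "facebook/nllb-200-distilled-600M"
tok = AutoTokenizer.from_pretrained(checkpoint)   # fast tokenizer by default
tok.add_tokens(["superlongword"])                 # matched before the BPE model runs
print(tok.tokenize("Hi there superlongword"))     # "superlongword" should appear as a single token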

kitkhai commented Jan 3, 2024

Additionally, is there a way to retrieve (and edit) the merge rules from "slow" & "fast" tokenizers respectively?
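A hedged sketch of one way to at least inspect and rewrite them for the fast tokenizer, reusing tokenizer_edited_fast from the reproduction above (the "model"/"merges"/"vocab" layout is an assumption based on the tokenizer.json format, not a documented editing API). The slow NllbTokenizer wraps a SentencePiece model, which stores pieces and scores rather than a BPE merges list:

import json
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

state = json.loads(tokenizer_edited_fast.backend_tokenizer.to_str())  # dump the Rust backend to JSON
merges = state["model"]["merges"]                                     # BPE merge rules
vocab = state["model"]["vocab"]                                       # piece -> id
print(len(merges), merges[:3])

# After editing `state`, it can be rebuilt into a fast tokenizer:
rebuilt = PreTrainedTokenizerFast(tokenizer_object=Tokenizer.from_str(json.dumps(state)))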

ArthurZucker (Collaborator) commented

Hey! A few things here. What you are trying to do is outside the scope of the supported features; adding a token should be done with the tokenizer.add_tokens function.
The fast version is, to me, more correct than what you expect: if there are no merges, there is absolutely no reason for the BPE model to fuse '▁super', 'long', 'word' into superlongword. So the slow version seems more wrong, specifically because sentencepiece does not really allow adding tokens that way.
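To make that concrete, a small check one could run on the reproduction above (the layout of the dumped backend JSON is an assumption based on the tokenizer.json format): the edited piece can sit in the fast tokenizer's BPE vocab, but without a merge rule that builds it, the BPE model will never output it.

import json

state = json.loads(tokenizer_edited_fast.backend_tokenizer.to_str())     # dump the Rust backend to JSON
print("superlongword" in state["model"]["vocab"])                        # is the piece in the BPE vocab?
print(any("superlongword" in str(m) for m in state["model"]["merges"]))  # does any merge rule produce it?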

kitkhai closed this as completed Jan 3, 2024
kitkhai closed this as not planned Jan 3, 2024