Added cleaned configuration properties for tokenizer with serialization - improve tokenization of XLM #1092

Merged
12 commits merged on Aug 30, 2019
13 changes: 10 additions & 3 deletions pytorch_transformers/modeling_xlm.py
@@ -44,6 +44,8 @@
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-pytorch_model.bin",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-pytorch_model.bin",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-pytorch_model.bin",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin",
}
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json",
@@ -54,6 +56,8 @@
'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json",
'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json",
'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json",
'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json",
'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json",
}


@@ -114,6 +118,7 @@ def __init__(self,
causal=False,
asm=False,
n_langs=1,
use_lang_emb=True,
max_position_embeddings=512,
embed_init_std=2048 ** -0.5,
layer_norm_eps=1e-12,
@@ -157,6 +162,7 @@ def __init__(self,
self.causal = causal
self.asm = asm
self.n_langs = n_langs
self.use_lang_emb = use_lang_emb
self.layer_norm_eps = layer_norm_eps
self.bos_index = bos_index
self.eos_index = eos_index
@@ -488,7 +494,7 @@ class XLMModel(XLMPreTrainedModel):

"""
ATTRIBUTES = ['encoder', 'eos_index', 'pad_index', # 'with_output',
'n_langs', 'n_words', 'dim', 'n_layers', 'n_heads',
'n_langs', 'use_lang_emb', 'n_words', 'dim', 'n_layers', 'n_heads',
'hidden_dim', 'dropout', 'attention_dropout', 'asm',
'asm_cutoffs', 'asm_div_value']

@@ -507,6 +513,7 @@ def __init__(self, config): #, dico, is_encoder, with_output):

# dictionary / languages
self.n_langs = config.n_langs
self.use_lang_emb = config.use_lang_emb
self.n_words = config.n_words
self.eos_index = config.eos_index
self.pad_index = config.pad_index
@@ -529,7 +536,7 @@ def __init__(self, config): #, dico, is_encoder, with_output):
self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
if config.sinusoidal_embeddings:
create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
if config.n_langs > 1:
if config.n_langs > 1 and config.use_lang_emb:
self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
@@ -628,7 +635,7 @@ def forward(self, input_ids, lengths=None, position_ids=None, langs=None,
# embeddings
tensor = self.embeddings(input_ids)
tensor = tensor + self.position_embeddings(position_ids).expand_as(tensor)
if langs is not None:
if langs is not None and self.use_lang_emb:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
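For context, a minimal usage sketch of the new use_lang_emb flag (not code from this PR; it assumes the public XLMConfig/XLMModel API and that the new 17- and 100-language checkpoints are the ones trained without language embeddings). With use_lang_emb=False the model builds no language-embedding table and ignores any langs tensor passed to forward:

import torch
from pytorch_transformers import XLMConfig, XLMModel

# Only n_langs and use_lang_emb are set explicitly; everything else keeps the
# (fairly large) XLMConfig defaults.
config = XLMConfig(n_langs=17, use_lang_emb=False)
model = XLMModel(config)

input_ids = torch.tensor([[0, 5, 6, 1]])  # dummy token ids
langs = torch.zeros_like(input_ids)       # ignored because use_lang_emb is False
last_hidden_state = model(input_ids, langs=langs)[0]
print(last_hidden_state.shape)            # (1, 4, emb_dim)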
4 changes: 2 additions & 2 deletions pytorch_transformers/tests/tokenization_bert_test.py
@@ -41,8 +41,8 @@ def setUp(self):
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

def get_tokenizer(self):
return self.tokenizer_class.from_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"UNwant\u00E9d,running"
4 changes: 2 additions & 2 deletions pytorch_transformers/tests/tokenization_dilbert_test.py
@@ -27,8 +27,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):

tokenizer_class = DistilBertTokenizer

def get_tokenizer(self):
return DistilBertTokenizer.from_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def test_sequence_builders(self):
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
5 changes: 3 additions & 2 deletions pytorch_transformers/tests/tokenization_gpt2_test.py
@@ -44,8 +44,9 @@ def setUp(self):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))

def get_tokenizer(self):
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return GPT2Tokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"lower newer"
4 changes: 2 additions & 2 deletions pytorch_transformers/tests/tokenization_openai_test.py
@@ -45,8 +45,8 @@ def setUp(self):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))

def get_tokenizer(self):
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"lower newer"
5 changes: 3 additions & 2 deletions pytorch_transformers/tests/tokenization_roberta_test.py
@@ -43,8 +43,9 @@ def setUp(self):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))

def get_tokenizer(self):
return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return RobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"lower newer"
17 changes: 13 additions & 4 deletions pytorch_transformers/tests/tokenization_tests_commons.py
@@ -49,23 +49,32 @@ def setUp(self):
def tearDown(self):
shutil.rmtree(self.tmpdirname)

def get_tokenizer(self):
def get_tokenizer(self, **kwargs):
raise NotImplementedError

def get_input_output_texts(self):
raise NotImplementedError

def test_save_and_load_tokenizer(self):
# safety check on max_len default value so we are sure the test works
tokenizer = self.get_tokenizer()
self.assertNotEqual(tokenizer.max_len, 42)

# Now let's start the test
tokenizer = self.get_tokenizer(max_len=42)

before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")

with TemporaryDirectory() as tmpdirname:
tokenizer.save_pretrained(tmpdirname)
tokenizer = tokenizer.from_pretrained(tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)

after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
self.assertListEqual(before_tokens, after_tokens)

after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
self.assertListEqual(before_tokens, after_tokens)
self.assertEqual(tokenizer.max_len, 42)
tokenizer = self.tokenizer_class.from_pretrained(tmpdirname, max_len=43)
self.assertEqual(tokenizer.max_len, 43)

def test_pickle_tokenizer(self):
tokenizer = self.get_tokenizer()
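The updated test above exercises the new tokenizer configuration serialization. A rough sketch of the intended behavior (the save directory is hypothetical; max_len is one example of a serialized kwarg):

import os
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=42)

os.makedirs('./my_tokenizer', exist_ok=True)
tokenizer.save_pretrained('./my_tokenizer')  # writes vocab, added tokens and the init kwargs

reloaded = BertTokenizer.from_pretrained('./my_tokenizer')
assert reloaded.max_len == 42                # serialized value is restored

overridden = BertTokenizer.from_pretrained('./my_tokenizer', max_len=43)
assert overridden.max_len == 43              # explicit kwargs still take precedence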
5 changes: 3 additions & 2 deletions pytorch_transformers/tests/tokenization_transfo_xl_test.py
@@ -37,8 +37,9 @@ def setUp(self):
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

def get_tokenizer(self):
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, lower_case=True)
def get_tokenizer(self, **kwargs):
kwargs['lower_case'] = True
return TransfoXLTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"<unk> UNwanted , running"
4 changes: 2 additions & 2 deletions pytorch_transformers/tests/tokenization_xlm_test.py
@@ -44,8 +44,8 @@ def setUp(self):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))

def get_tokenizer(self):
return XLMTokenizer.from_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"lower newer"
4 changes: 2 additions & 2 deletions pytorch_transformers/tests/tokenization_xlnet_test.py
@@ -35,8 +35,8 @@ def setUp(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)

def get_tokenizer(self):
return XLNetTokenizer.from_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)

def get_input_output_texts(self):
input_text = u"This is a test"
36 changes: 18 additions & 18 deletions pytorch_transformers/tokenization_bert.py
@@ -63,6 +63,23 @@
'bert-base-cased-finetuned-mrpc': 512,
}

PRETRAINED_INIT_CONFIGURATION = {
'bert-base-uncased': {'do_lower_case': True},
'bert-large-uncased': {'do_lower_case': True},
'bert-base-cased': {'do_lower_case': False},
'bert-large-cased': {'do_lower_case': False},
'bert-base-multilingual-uncased': {'do_lower_case': True},
'bert-base-multilingual-cased': {'do_lower_case': False},
'bert-base-chinese': {'do_lower_case': False},
'bert-base-german-cased': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking': {'do_lower_case': True},
'bert-large-cased-whole-word-masking': {'do_lower_case': False},
'bert-large-uncased-whole-word-masking-finetuned-squad': {'do_lower_case': True},
'bert-large-cased-whole-word-masking-finetuned-squad': {'do_lower_case': False},
'bert-base-cased-finetuned-mrpc': {'do_lower_case': False},
}


def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
@@ -100,6 +117,7 @@ class BertTokenizer(PreTrainedTokenizer):

vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
@@ -202,24 +220,6 @@ def save_vocabulary(self, vocab_path):
index += 1
return (vocab_file,)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
""" Instantiate a BertTokenizer from pre-trained vocabulary files.
"""
if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is a cased model but you have not set "
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
"you may want to check this behavior.")
kwargs['do_lower_case'] = False
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
logger.warning("The pre-trained model you are loading is an uncased model but you have set "
"`do_lower_case` to False. We are setting `do_lower_case=True` for you "
"but you may want to check this behavior.")
kwargs['do_lower_case'] = True

return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)


class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
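With PRETRAINED_INIT_CONFIGURATION in place, the name-sniffing from_pretrained override removed above becomes unnecessary: the correct do_lower_case value is looked up per checkpoint. A rough illustration of the expected effect (an assumption about intended behavior, not code from this PR; basic_tokenizer.do_lower_case is an internal attribute):

from pytorch_transformers import BertTokenizer

cased = BertTokenizer.from_pretrained('bert-base-cased')      # resolves do_lower_case=False
uncased = BertTokenizer.from_pretrained('bert-base-uncased')  # resolves do_lower_case=True

print(cased.basic_tokenizer.do_lower_case)    # False
print(uncased.basic_tokenizer.do_lower_case)  # True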
3 changes: 2 additions & 1 deletion pytorch_transformers/tokenization_transfo_xl.py
@@ -95,7 +95,8 @@ def __init__(self, special=None, min_freq=0, max_size=None, lower_case=False,
# in a library like ours, at all.
vocab_dict = torch.load(pretrained_vocab_file)
for key, value in vocab_dict.items():
self.__dict__[key] = value
if key not in self.__dict__:
self.__dict__[key] = value

if vocab_file is not None:
self.build_vocab()
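A self-contained sketch of why the new guard matters (the dictionary values are hypothetical): attributes that were already set from explicit keyword arguments should not be clobbered by the pickled vocabulary dict loaded with torch.load:

already_set = {'lower_case': True}                 # came from an explicit kwarg
vocab_dict = {'lower_case': False, 'min_freq': 0}  # loaded from pretrained_vocab_file

for key, value in vocab_dict.items():
    if key not in already_set:   # the added guard: keep user-supplied values
        already_set[key] = value

print(already_set)  # {'lower_case': True, 'min_freq': 0}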