Fixing the vocab size of the trained Unigram model (#952)
* Fixing the vocab size of the trained Unigram model

* add test for the vocab size of the trained Unigram model

* Revert "add test for the vocab size of the trained Unigram model"

This reverts commit fb8955c.

* Fixing the vocab size of the trained Unigram model

* format codes

* get the position of vocab-size calculation out of loop
Kaito Sugimoto authored Mar 18, 2022
1 parent daa4dd2 commit 1bb9884
Showing 2 changed files with 42 additions and 12 deletions.
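
Before this change, the trainer filled the learned pieces up to the requested vocab_size and only afterwards appended the special tokens (plus an extra unk token when it was not already listed among them), so the trained Unigram model's vocabulary came out larger than requested. The fix reserves slots for those tokens before the fill loop. A minimal sketch of that reserved-slot arithmetic (illustration only, not part of the diff), using the numbers from the first new test case:

    # Reserved-slot arithmetic introduced by this commit (illustrative numbers).
    vocab_size = 100                              # requested total vocabulary size
    special_tokens = ["[PAD]", "[SEP]", "[CLS]"]  # as in the first new test case
    need_add_unk = True                           # "[UNK]" is not among the special tokens

    # Learned pieces must leave room for the special tokens (and "[UNK]" when it has to be added).
    budget = vocab_size - len(special_tokens) - (1 if need_add_unk else 0)
    assert budget == 96  # 96 learned pieces + 3 special tokens + "[UNK]" == 100

In the second test case, where "[UNK]" already appears among the special tokens, the budget is 100 - 4 = 96 as well, so get_vocab_size() again reports exactly 100.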
22 changes: 22 additions & 0 deletions bindings/python/tests/bindings/test_trainers.py
@@ -238,6 +238,28 @@ def test_train_with_special_tokens(self):
            "[SEP]",
        ]

        tokenizer = Tokenizer(models.Unigram())
        trainer = trainers.UnigramTrainer(
            show_progress=False,
            special_tokens=["[PAD]", "[SEP]", "[CLS]"],
            unk_token="[UNK]",
            vocab_size=100,
        )
        tokenizer.train([filename], trainer=trainer)

        assert tokenizer.get_vocab_size() == 100

        tokenizer = Tokenizer(models.Unigram())
        trainer = trainers.UnigramTrainer(
            show_progress=False,
            special_tokens=["[PAD]", "[SEP]", "[CLS]", "[UNK]"],
            unk_token="[UNK]",
            vocab_size=100,
        )
        tokenizer.train([filename], trainer=trainer)

        assert tokenizer.get_vocab_size() == 100

    def test_cannot_train_different_model(self):
        tokenizer = Tokenizer(models.BPE())
        trainer = trainers.UnigramTrainer(show_progress=False)
32 changes: 20 additions & 12 deletions tokenizers/src/models/unigram/trainer.rs
@@ -126,19 +126,7 @@ impl UnigramTrainer {
                min_score_penalty += min_score_penalty_delta;
            }
        }
        for (token, score) in model.iter() {
            if inserted.contains::<str>(token) {
                continue;
            }
            inserted.insert(token.to_string());
            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
            if pieces.len() == self.vocab_size as usize {
                break;
            }
        }
        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());

        // Insert the necessary tokens
        let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
            let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
                if t.content == *unk {
@@ -154,6 +142,26 @@
        } else {
            (None, false)
        };

        let vocab_size_without_special_tokens = if need_add_unk {
            self.vocab_size as usize - self.special_tokens.len() - 1
        } else {
            self.vocab_size as usize - self.special_tokens.len()
        };
        for (token, score) in model.iter() {
            if inserted.contains::<str>(token) {
                continue;
            }
            inserted.insert(token.to_string());
            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));

            if pieces.len() == vocab_size_without_special_tokens {
                break;
            }
        }
        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());

        // Insert the necessary tokens
        let mut special_tokens = self
            .special_tokens
            .iter()
