
Commit 1bb9884

Author: Kaito Sugimoto
Fixing the vocab size of the trained Unigram model (#952)
* Fixing the vocab size of the trained Unigram model
* add test for the vocab size of the trained Unigram model
* Revert "add test for the vocab size of the trained Unigram model" (this reverts commit fb8955c)
* Fixing the vocab size of the trained Unigram model
* format codes
* get the position of vocab-size calculation out of loop
1 parent daa4dd2 commit 1bb9884
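
Before this change, the trainer filled the learned pieces all the way up to vocab_size and only afterwards inserted the special tokens, so the final vocabulary could overshoot the requested size. The fix subtracts the special tokens (plus one extra slot when the UNK token must be added separately) from the piece budget before filling, so the trained model lands exactly on vocab_size. A minimal sketch of how to check this through the Python bindings, assuming a small plain-text training file at the hypothetical path corpus.txt:

    from tokenizers import Tokenizer, models, trainers

    # "corpus.txt" is a hypothetical training file; any plain-text file works.
    tokenizer = Tokenizer(models.Unigram())
    trainer = trainers.UnigramTrainer(
        show_progress=False,
        special_tokens=["[PAD]", "[SEP]", "[CLS]"],
        unk_token="[UNK]",
        vocab_size=100,
    )
    tokenizer.train(["corpus.txt"], trainer=trainer)

    # With this fix, special tokens count toward vocab_size, so the
    # trained vocabulary is exactly the requested 100 entries.
    assert tokenizer.get_vocab_size() == 100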

2 files changed: +42 -12 lines


bindings/python/tests/bindings/test_trainers.py

Lines changed: 22 additions & 0 deletions
@@ -238,6 +238,28 @@ def test_train_with_special_tokens(self):
             "[SEP]",
         ]
 
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False,
+            special_tokens=["[PAD]", "[SEP]", "[CLS]"],
+            unk_token="[UNK]",
+            vocab_size=100,
+        )
+        tokenizer.train([filename], trainer=trainer)
+
+        assert tokenizer.get_vocab_size() == 100
+
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False,
+            special_tokens=["[PAD]", "[SEP]", "[CLS]", "[UNK]"],
+            unk_token="[UNK]",
+            vocab_size=100,
+        )
+        tokenizer.train([filename], trainer=trainer)
+
+        assert tokenizer.get_vocab_size() == 100
+
     def test_cannot_train_different_model(self):
         tokenizer = Tokenizer(models.BPE())
         trainer = trainers.UnigramTrainer(show_progress=False)
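
The two new cases differ only in whether "[UNK]" is already listed in special_tokens: when it is not, the trainer has to reserve one extra slot for it. Either way, the trained vocabulary should come out at exactly vocab_size=100.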

tokenizers/src/models/unigram/trainer.rs

Lines changed: 20 additions & 12 deletions
@@ -126,19 +126,7 @@ impl UnigramTrainer {
                 min_score_penalty += min_score_penalty_delta;
             }
         }
-        for (token, score) in model.iter() {
-            if inserted.contains::<str>(token) {
-                continue;
-            }
-            inserted.insert(token.to_string());
-            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
-            if pieces.len() == self.vocab_size as usize {
-                break;
-            }
-        }
-        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
 
-        // Insert the necessary tokens
         let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
             let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
                 if t.content == *unk {
@@ -154,6 +142,26 @@ impl UnigramTrainer {
         } else {
             (None, false)
         };
+
+        let vocab_size_without_special_tokens = if need_add_unk {
+            self.vocab_size as usize - self.special_tokens.len() - 1
+        } else {
+            self.vocab_size as usize - self.special_tokens.len()
+        };
+        for (token, score) in model.iter() {
+            if inserted.contains::<str>(token) {
+                continue;
+            }
+            inserted.insert(token.to_string());
+            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
+
+            if pieces.len() == vocab_size_without_special_tokens {
+                break;
+            }
+        }
+        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
+
+        // Insert the necessary tokens
         let mut special_tokens = self
             .special_tokens
             .iter()
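
The core of the fix is the budget computation above: the number of pieces learned from the corpus is the requested vocab_size minus the special tokens, minus one more slot when the UNK token is absent from them and must be added separately. Moving the selection loop below the unk_id lookup is what makes need_add_unk available to that computation. A worked sketch of the arithmetic in Python (the helper name is illustrative, not part of the library):

    # Illustrative helper mirroring the new Rust budget computation;
    # not part of the tokenizers API.
    def learned_piece_budget(vocab_size, special_tokens, unk_token):
        # UNK needs its own slot only when it is absent from special_tokens.
        need_add_unk = unk_token is not None and unk_token not in special_tokens
        return vocab_size - len(special_tokens) - (1 if need_add_unk else 0)

    # Both cases from the new test leave room for 96 learned pieces,
    # so 96 pieces + 4 reserved slots = the requested 100.
    assert learned_piece_budget(100, ["[PAD]", "[SEP]", "[CLS]"], "[UNK]") == 96
    assert learned_piece_budget(100, ["[PAD]", "[SEP]", "[CLS]", "[UNK]"], "[UNK]") == 96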
