
Commit 39a2e80

Cleanup.
1 parent 0c7ab4f commit 39a2e80

37 files changed: +626 -783 lines

tokenizers/README.md

Lines changed: 5 additions & 5 deletions
@@ -95,11 +95,11 @@ fn main() -> Result<()> {
         .vocab_size(vocab_size)
         .min_frequency(0)
         .special_tokens(vec![
-            AddedToken::from(String::from("<s>"), true),
-            AddedToken::from(String::from("<pad>"), true),
-            AddedToken::from(String::from("</s>"), true),
-            AddedToken::from(String::from("<unk>"), true),
-            AddedToken::from(String::from("<mask>"), true),
+            AddedToken::from("<s>", true),
+            AddedToken::from("<pad>", true),
+            AddedToken::from("</s>", true),
+            AddedToken::from("<unk>", true),
+            AddedToken::from("<mask>", true),
         ])
         .build();
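
Note: dropping the `String::from` wrappers compiles because `AddedToken::from` takes its content through a generic `Into`-style bound rather than a concrete `String`. A minimal sketch of that pattern in isolation (the `Tok` type is a hypothetical stand-in, not this crate's `AddedToken`):

    // Hypothetical stand-in: a constructor bounded on Into<String>
    // accepts &str literals and owned Strings alike.
    #[allow(dead_code)]
    struct Tok {
        content: String,
        special: bool,
    }

    impl Tok {
        fn from(content: impl Into<String>, special: bool) -> Self {
            Self {
                content: content.into(),
                special,
            }
        }
    }

    fn main() {
        // Both call styles compile; the literal no longer needs String::from.
        let _a = Tok::from("<s>", true);
        let _b = Tok::from(String::from("<s>"), true);
    }

The same conversion-bound change recurs throughout this commit, e.g. in `BPEDecoder::new` and `CTC::new` below.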

tokenizers/benches/bert_benchmark.rs

Lines changed: 2 additions & 2 deletions
@@ -38,8 +38,8 @@ fn create_bert_tokenizer(wp: WordPiece) -> BertTokenizer {
     tokenizer.with_normalizer(Some(BertNormalizer::default()));
     tokenizer.with_decoder(Some(decoders::wordpiece::WordPiece::default()));
     tokenizer.with_post_processor(Some(BertProcessing::new(
-        ("[SEP]".into(), sep_id),
-        ("[CLS]".into(), cls_id),
+        ("[SEP]", sep_id),
+        ("[CLS]", cls_id),
     )));
     tokenizer
 }

tokenizers/src/decoders/bpe.rs

Lines changed: 7 additions & 5 deletions
@@ -9,18 +9,20 @@ use serde::{Deserialize, Serialize};
 #[serde(tag = "type")]
 #[non_exhaustive]
 pub struct BPEDecoder {
-    pub suffix: String,
+    pub suffix: CompactString,
 }
 
 impl BPEDecoder {
-    pub fn new(suffix: String) -> Self {
-        Self { suffix }
+    pub fn new(suffix: impl Into<CompactString>) -> Self {
+        Self {
+            suffix: suffix.into(),
+        }
     }
 }
 
 impl Default for BPEDecoder {
     fn default() -> Self {
-        Self::new("</w>".into())
+        Self::new("</w>")
     }
 }
 
@@ -37,7 +39,7 @@ impl Decoder for BPEDecoder {
             let replacement = if i == n { "" } else { " " };
             token
                 .to_compact_string()
-                .replace(&self.suffix, replacement)
+                .replace(&*self.suffix, replacement)
                 .to_compact_string()
         })
         .collect::<Vec<CompactString>>())
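
Note: two details here are easy to miss. `new` now takes `impl Into<CompactString>`, which is why `Default` can say `Self::new("</w>")` without `.into()`. And the replace call gains a `&*`: `str::replace` expects a `Pattern`, which `&CompactString` does not implement on stable Rust (the `Pattern` trait is unstable outside std), while `CompactString` derefs to `str`, so `&*self.suffix` reborrows as a plain `&str`. A minimal sketch, assuming only the `compact_str` crate:

    use compact_str::CompactString;

    fn main() {
        let suffix = CompactString::from("</w>");
        // `&suffix` would not compile as a pattern; `&*suffix` reborrows
        // through Deref<Target = str> and yields a &str, which is one.
        let cleaned = "hello</w>world</w>".replace(&*suffix, " ");
        assert_eq!(cleaned, "hello world ");
    }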

tokenizers/src/decoders/byte_fallback.rs

Lines changed: 12 additions & 31 deletions
@@ -78,49 +78,41 @@ mod tests {
     #[test]
     fn decode() {
         let decoder = ByteFallback::new();
-        let res = decoder
-            .decode_chain(vec!["Hey".to_owned(), "friend!".to_owned()])
-            .unwrap();
+        let res = decoder.decode_chain(vec!["Hey", "friend!"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())
                 .collect::<Vec<_>>(),
-            vec!["Hey".to_owned(), "friend!".to_owned()]
+            vec!["Hey", "friend!"]
         );
 
-        let res = decoder.decode_chain(vec!["<0x61>".to_owned()]).unwrap();
+        let res = decoder.decode_chain(vec!["<0x61>"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())
                 .collect::<Vec<_>>(),
-            vec!["a".to_owned()]
+            vec!["a"]
         );
 
-        let res = decoder.decode_chain(vec!["<0xE5>".to_owned()]).unwrap();
+        let res = decoder.decode_chain(vec!["<0xE5>"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())
                 .collect::<Vec<_>>(),
             vec!["�"]
         );
 
-        let res = decoder
-            .decode_chain(vec!["<0xE5>".to_owned(), "<0x8f>".to_owned()])
-            .unwrap();
+        let res = decoder.decode_chain(vec!["<0xE5>", "<0x8f>"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())
                 .collect::<Vec<_>>(),
-            vec!["�".to_owned(), "�".to_owned()]
+            vec!["�", "�"]
         );
 
         // 叫
         let res = decoder
-            .decode_chain(vec![
-                "<0xE5>".to_owned(),
-                "<0x8f>".to_owned(),
-                "<0xab>".to_owned(),
-            ])
+            .decode_chain(vec!["<0xE5>", "<0x8f>", "<0xab>"])
             .unwrap();
         assert_eq!(
             res.into_iter()
@@ -130,32 +122,21 @@ mod tests {
         );
 
         let res = decoder
-            .decode_chain(vec![
-                "<0xE5>".to_owned(),
-                "<0x8f>".to_owned(),
-                "<0xab>".to_owned(),
-                "a".to_owned(),
-            ])
+            .decode_chain(vec!["<0xE5>", "<0x8f>", "<0xab>", "a"])
             .unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())
                 .collect::<Vec<_>>(),
-            vec!["叫".to_owned(), "a".to_owned()]
+            vec!["叫", "a"]
         );
 
-        let res = decoder
-            .decode_chain(vec![
-                "<0xE5>".to_owned(),
-                "<0x8f>".to_owned(),
-                "a".to_owned(),
-            ])
-            .unwrap();
+        let res = decoder.decode_chain(vec!["<0xE5>", "<0x8f>", "a"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())
                 .collect::<Vec<_>>(),
-            vec!["�".to_owned(), "�".to_owned(), "a".to_owned()]
+            vec!["�", "�", "a"]
         );
     }
 }
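
Note: the expectations above pin down the byte-fallback contract: tokens shaped like `<0xNN>` are buffered as raw bytes, a buffered run that forms valid UTF-8 decodes to a single token (`叫` from `E5 8F AB`), and an invalid run yields one `�` (U+FFFD) per byte. A standalone sketch consistent with those tests (not the crate's actual implementation):

    // Parse a byte-fallback token like "<0xE5>" into its byte value.
    fn parse_byte_token(tok: &str) -> Option<u8> {
        let hex = tok.strip_prefix("<0x")?.strip_suffix('>')?;
        u8::from_str_radix(hex, 16).ok()
    }

    // Flush buffered bytes: valid UTF-8 becomes one token, otherwise one
    // replacement character per byte, matching the assertions above.
    fn flush(bytes: &mut Vec<u8>, out: &mut Vec<String>) {
        if bytes.is_empty() {
            return;
        }
        match std::str::from_utf8(bytes) {
            Ok(s) => out.push(s.to_owned()),
            Err(_) => out.extend(bytes.iter().map(|_| "\u{FFFD}".to_owned())),
        }
        bytes.clear();
    }

    fn decode(tokens: &[&str]) -> Vec<String> {
        let (mut out, mut bytes) = (Vec::new(), Vec::new());
        for &tok in tokens {
            match parse_byte_token(tok) {
                Some(b) => bytes.push(b),
                None => {
                    flush(&mut bytes, &mut out);
                    out.push(tok.to_owned());
                }
            }
        }
        flush(&mut bytes, &mut out);
        out
    }

    fn main() {
        assert_eq!(decode(&["<0xE5>", "<0x8f>", "<0xab>", "a"]), ["叫", "a"]);
        assert_eq!(decode(&["<0xE5>", "<0x8f>", "a"]), ["�", "�", "a"]);
    }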

tokenizers/src/decoders/ctc.rs

Lines changed: 12 additions & 12 deletions
@@ -14,31 +14,31 @@ use serde::{Deserialize, Serialize};
 #[non_exhaustive]
 pub struct CTC {
     /// The pad token used by CTC to delimit a new token.
-    pub pad_token: String,
+    pub pad_token: CompactString,
     /// The word delimiter token. It will be replaced by a `<space>`.
-    pub word_delimiter_token: String,
+    pub word_delimiter_token: CompactString,
     /// Whether to cleanup some tokenization artifacts.
     /// Mainly spaces before punctuation, and some abbreviated english forms.
     pub cleanup: bool,
 }
 
 impl CTC {
-    pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self {
+    pub fn new(
+        pad_token: impl Into<CompactString>,
+        word_delimiter_token: impl Into<CompactString>,
+        cleanup: bool,
+    ) -> Self {
         Self {
-            pad_token,
-            word_delimiter_token,
+            pad_token: pad_token.into(),
+            word_delimiter_token: word_delimiter_token.into(),
            cleanup,
         }
     }
 }
 
 impl Default for CTC {
     fn default() -> Self {
-        Self {
-            pad_token: "<pad>".to_string(),
-            word_delimiter_token: "|".to_string(),
-            cleanup: true,
-        }
+        Self::new("<pad>", "|", true)
     }
 }
 
@@ -52,10 +52,10 @@ impl Decoder for CTC {
             .map(|token| token.to_compact_string())
             .dedup()
             .filter_map(|token| {
-                let mut replaced: CompactString = token.replace(&self.pad_token, "").into();
+                let mut replaced: CompactString = token.replace(&*self.pad_token, "").into();
                 if self.cleanup {
                     replaced = wordpiece::cleanup(&replaced)
-                        .replace(&self.word_delimiter_token, " ")
+                        .replace(&*self.word_delimiter_token, " ")
                         .into();
                 }
                 if replaced.is_empty() {
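
Note: with the generic constructor, both calling styles still compile: `String` satisfies `Into<CompactString>` via `compact_str`'s `From<String>` impl, and `&str` now needs no conversion at all (short literals like these are stored inline, with no intermediate `String`). A usage sketch, assuming this fork keeps the crate's usual `tokenizers::decoders::ctc::CTC` path:

    use tokenizers::decoders::ctc::CTC;

    fn main() {
        // Pre-change style still compiles: String converts to CompactString.
        let _verbose = CTC::new("<pad>".to_string(), "|".to_string(), true);
        // Post-change style: &str literals, no conversion at the call site.
        let _concise = CTC::new("<pad>", "|", true);
    }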

tokenizers/src/decoders/fuse.rs

Lines changed: 1 addition & 3 deletions
@@ -44,9 +44,7 @@ mod tests {
     #[test]
     fn decode() {
         let decoder = Fuse::new();
-        let res = decoder
-            .decode_chain(vec!["Hey".to_owned(), " friend!".to_owned()])
-            .unwrap();
+        let res = decoder.decode_chain(vec!["Hey", " friend!"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())

tokenizers/src/decoders/strip.rs

Lines changed: 2 additions & 8 deletions
@@ -71,11 +71,7 @@ mod tests {
     fn decode() {
         let decoder = Strip::new('H', 1, 0);
         let res = decoder
-            .decode_chain(vec![
-                "Hey".to_owned(),
-                " friend!".to_owned(),
-                "HHH".to_owned(),
-            ])
+            .decode_chain(vec!["Hey", " friend!", "HHH"])
             .unwrap();
         assert_eq!(
             res.into_iter()
@@ -85,9 +81,7 @@ mod tests {
         );
 
         let decoder = Strip::new('y', 0, 1);
-        let res = decoder
-            .decode_chain(vec!["Hey".to_owned(), " friend!".to_owned()])
-            .unwrap();
+        let res = decoder.decode_chain(vec!["Hey", " friend!"]).unwrap();
         assert_eq!(
             res.into_iter()
                 .map(|t| t.to_compact_string())

tokenizers/src/decoders/wordpiece.rs

Lines changed: 1 addition & 8 deletions
@@ -82,14 +82,7 @@ mod tests {
 
         assert_eq!(
             decoder
-                .decode(vec![
-                    "##uelo".to_owned(),
-                    "Ara".to_owned(),
-                    "##új".to_owned(),
-                    "##o".to_owned(),
-                    "No".to_owned(),
-                    "##guera".to_owned()
-                ])
+                .decode(vec!["##uelo", "Ara", "##új", "##o", "No", "##guera"])
                 .unwrap()
                 .to_compact_string(),
             "##uelo Araújo Noguera"

tokenizers/src/lib.rs

Lines changed: 5 additions & 5 deletions
@@ -83,11 +83,11 @@
 //!     .vocab_size(vocab_size)
 //!     .min_frequency(0)
 //!     .special_tokens(vec![
-//!         AddedToken::from(String::from("<s>"), true),
-//!         AddedToken::from(String::from("<pad>"), true),
-//!         AddedToken::from(String::from("</s>"), true),
-//!         AddedToken::from(String::from("<unk>"), true),
-//!         AddedToken::from(String::from("<mask>"), true),
+//!         AddedToken::from("<s>", true),
+//!         AddedToken::from("<pad>", true),
+//!         AddedToken::from("</s>", true),
+//!         AddedToken::from("<unk>", true),
+//!         AddedToken::from("<mask>", true),
 //!     ])
 //!     .build();
 //!

tokenizers/src/models/bpe/mod.rs

Lines changed: 2 additions & 3 deletions
@@ -1,6 +1,5 @@
 //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
 use std::{iter, mem};
-use compact_str::CompactString;
 
 mod model;
 mod serialization;
@@ -27,10 +26,10 @@ pub enum Error {
     BadMerges(usize),
     /// If a token found in merges, is not in the vocab
     #[error("Token `{0}` out of vocabulary")]
-    MergeTokenOutOfVocabulary(CompactString),
+    MergeTokenOutOfVocabulary(String),
     /// If the provided unk token is out of vocabulary
     #[error("Unk token `{0}` not found in the vocabulary")]
-    UnkTokenOutOfVocabulary(CompactString),
+    UnkTokenOutOfVocabulary(String),
     /// Dropout not between 0 and 1.
     #[error("Dropout should be between 0 and 1, inclusive")]
     InvalidDropout,
