diff --git a/tokenizers/examples/train_bpe.rs b/tokenizers/examples/train_bpe.rs
index 4c12951b..f080b1ed 100644
--- a/tokenizers/examples/train_bpe.rs
+++ b/tokenizers/examples/train_bpe.rs
@@ -11,33 +11,36 @@ use std::path::Path;
fn main() -> Result<()> {
let vocab_size: usize = 100;
+ let min_frequency = 0;
+ let add_prefix_space = false;
+ let trim_offsets = false;
+ let use_regex = false;
+
let mut trainer = BpeTrainerBuilder::new()
.show_progress(true)
.vocab_size(vocab_size)
- .min_frequency(0)
+ .min_frequency(min_frequency)
.special_tokens(vec![
- AddedToken::from(String::from("<s>"), true),
- AddedToken::from(String::from("<pad>"), true),
- AddedToken::from(String::from("</s>"), true),
- AddedToken::from(String::from("<unk>"), true),
- AddedToken::from(String::from("<mask>"), true),
+ AddedToken::from(String::from("<s>"), true),
+ AddedToken::from(String::from("<pad>"), true),
+ AddedToken::from(String::from("</s>"), true),
+ AddedToken::from(String::from("<unk>"), true),
+ AddedToken::from(String::from("<mask>"), true),
])
.build();
let mut tokenizer = TokenizerBuilder::new()
.with_model(BPE::default())
- .with_normalizer(Some(Sequence::new(vec![
- Strip::new(true, true).into(),
- NFC.into(),
- ])))
- .with_pre_tokenizer(Some(ByteLevel::default()))
- .with_post_processor(Some(ByteLevel::default()))
- .with_decoder(Some(ByteLevel::default()))
+ .with_normalizer(Some(Sequence::new(vec![])))
+ .with_pre_tokenizer(Some(ByteLevel::new(add_prefix_space, trim_offsets, use_regex)))
+ .with_post_processor(Some(ByteLevel::new(add_prefix_space, trim_offsets, use_regex)))
+ .with_decoder(Some(ByteLevel::new(add_prefix_space, trim_offsets, use_regex)))
.build()?;
let pretty = false;
tokenizer
.train_from_pretokenized_data(
+ // .train_from_files(
&mut trainer,
vec!["/home/felipe_cohere_com/pretokenized.tsv".to_string()],
)?
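The trainer hunks below split each line on '\t' and parse the trailing field as a frequency, so the pretokenized TSV consumed above is presumably one sequence per line followed by a tab and its count. A minimal sketch (hypothetical path and contents) of producing such a file:

use std::fs;
use std::io;

// Write a toy pretokenized.tsv: one "sequence<TAB>count" record per line,
// matching the split('\t') / trailing-count parsing used by the trainers.
fn write_toy_pretokenized_tsv(path: &str) -> io::Result<()> {
    let lines = [
        "the quick brown fox\t42",
        "hello world\t7",
    ];
    fs::write(path, lines.join("\n"))
}
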
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index b53070ac..8fb81862 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -683,16 +683,19 @@ impl Trainer for BpeTrainer {
.maybe_par_bridge()
.map(|sequence| {
let split_seq = process(sequence.as_ref())?;
+ let count_str = split_seq.last().expect("pretokenized line should end with a count").to_string();
let segment = split_seq[0]
.replace("", "\n")
.replace("", "\t")
.replace("", "\r")
.clone();
- let mut count_str = split_seq[1].to_string();
+
let count_int: u64 = count_str.trim_end().parse().unwrap();
let mut map = HashMap::new();
- map.insert(segment, count_int);
+ for token in &split_seq[..split_seq.len() - 1] {
+     *map.entry(token.clone()).or_insert(0) += count_int;
+ }
Ok(map)
})
.reduce(
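
In isolation, the per-line parsing above amounts to: split on '\t', read the last field as the frequency, and credit that frequency to every preceding segment (the per-line maps are then summed in the reduce step). A standalone sketch of that logic, accumulating with the entry API so repeated segments on one line add up rather than overwrite:

use std::collections::HashMap;

// Parse one pretokenized line ("seg1<TAB>seg2<TAB>...<TAB>count") into a
// segment -> count map for that line.
fn counts_for_line(line: &str) -> HashMap<String, u64> {
    let fields: Vec<&str> = line.split('\t').collect();
    let count: u64 = fields.last().unwrap().trim_end().parse().unwrap();
    let mut map = HashMap::new();
    for segment in &fields[..fields.len() - 1] {
        *map.entry(segment.to_string()).or_insert(0) += count;
    }
    map
}
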
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index d83b507a..065edffb 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -666,14 +666,19 @@ impl Trainer for UnigramTrainer {
.maybe_par_bridge()
.map(|sequence| {
let split_seq = process(sequence.as_ref())?;
+ let count_str = split_seq.last().expect("pretokenized line should end with a count").to_string();
let segment = split_seq[0]
.replace("", "\n")
.replace("", "\t")
.replace("", "\r")
- .clone(); let mut count_str = split_seq[1].to_string();
+ .clone();
+
+
let count_int: u32 = count_str.trim_end().parse().unwrap();
let mut map = HashMap::new();
- map.insert(segment, count_int);
+ for token in &split_seq[..split_seq.len() - 1] {
+     *map.entry(token.clone()).or_insert(0) += count_int;
+ }
Ok(map)
})
.reduce(
diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs
index 21ee1edf..52ed0f2d 100644
--- a/tokenizers/src/models/wordlevel/trainer.rs
+++ b/tokenizers/src/models/wordlevel/trainer.rs
@@ -135,14 +135,19 @@ impl Trainer for WordLevelTrainer {
.maybe_par_bridge()
.map(|sequence| {
let split_seq = process(sequence.as_ref())?;
+ let count_str = split_seq.last().expect("pretokenized line should end with a count").to_string();
let segment = split_seq[0]
.replace("", "\n")
.replace("", "\t")
.replace("", "\r")
- .clone(); let mut count_str = split_seq[1].to_string();
+ .clone();
+
+
let count_int: u64 = count_str.trim_end().parse().unwrap();
let mut map = HashMap::new();
- map.insert(segment, count_int);
+ for token in &split_seq[..split_seq.len() - 1] {
+     *map.entry(token.clone()).or_insert(0) += count_int;
+ }
Ok(map)
})
.reduce(
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 61ea10fa..a9041422 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -1198,7 +1198,15 @@ where
}),
|seq| {
let split_seq: Vec<String> = seq.split('\t').map(|s| s.to_string()).collect();
- Ok(split_seq)
+ let normalized = self.do_normalize(split_seq[0].clone())?;
+ let pre_tokenized = self.do_pre_tokenize(normalized)?;
+ let mut pretokenized_and_freq: Vec<_> = pre_tokenized
+ .get_splits(OffsetReferential::Original, OffsetType::Byte)
+ .into_iter()
+ .map(|(s, _, _)| s.to_owned())
+ .collect();
+ pretokenized_and_freq.push(split_seq[1].clone());
+ Ok(pretokenized_and_freq)
},
)?;
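
For reference, the closure above now turns a "raw text<TAB>count" line into the pre-tokenized segments of the text followed by the count as the last element, which is the layout the trainer hunks expect. A rough standalone equivalent using the public pre-tokenizer API instead of the private do_normalize/do_pre_tokenize helpers (ByteLevel::new signature and crate re-exports assumed to match this fork):

use tokenizers::pre_tokenizers::byte_level::ByteLevel;
use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer, Result};

// Split one TSV line, byte-level pre-tokenize the text part, and re-append
// the frequency as the final field.
fn split_line(line: &str) -> Result<Vec<String>> {
    let fields: Vec<&str> = line.split('\t').collect();
    let mut pretok = PreTokenizedString::from(fields[0]);
    ByteLevel::new(false, false, false).pre_tokenize(&mut pretok)?;
    let mut out: Vec<String> = pretok
        .get_splits(OffsetReferential::Original, OffsetType::Byte)
        .into_iter()
        .map(|(s, _, _)| s.to_owned())
        .collect();
    out.push(fields[1].to_string());
    Ok(out)
}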