diff --git a/tokenizers/examples/train_bpe.rs b/tokenizers/examples/train_bpe.rs
index 4c12951b..f080b1ed 100644
--- a/tokenizers/examples/train_bpe.rs
+++ b/tokenizers/examples/train_bpe.rs
@@ -11,33 +11,36 @@ use std::path::Path;
 fn main() -> Result<()> {
     let vocab_size: usize = 100;

+    let min_frequency = 0;
+    let add_prefix_space = false;
+    let trim_offsets = false;
+    let use_regex = false;
+
     let mut trainer = BpeTrainerBuilder::new()
         .show_progress(true)
         .vocab_size(vocab_size)
-        .min_frequency(0)
+        .min_frequency(min_frequency)
         .special_tokens(vec![
-            AddedToken::from(String::from(""), true),
-            AddedToken::from(String::from(""), true),
-            AddedToken::from(String::from(""), true),
-            AddedToken::from(String::from(""), true),
-            AddedToken::from(String::from(""), true),
+            AddedToken::from(String::from(""), true),
+            AddedToken::from(String::from(""), true),
+            AddedToken::from(String::from(""), true),
+            AddedToken::from(String::from(""), true),
+            AddedToken::from(String::from(""), true),
         ])
         .build();

     let mut tokenizer = TokenizerBuilder::new()
         .with_model(BPE::default())
-        .with_normalizer(Some(Sequence::new(vec![
-            Strip::new(true, true).into(),
-            NFC.into(),
-        ])))
-        .with_pre_tokenizer(Some(ByteLevel::default()))
-        .with_post_processor(Some(ByteLevel::default()))
-        .with_decoder(Some(ByteLevel::default()))
+        .with_normalizer(Some(Sequence::new(vec![])))
+        .with_pre_tokenizer(Some(ByteLevel::new(add_prefix_space, trim_offsets, use_regex)))
+        .with_post_processor(Some(ByteLevel::new(add_prefix_space, trim_offsets, use_regex)))
+        .with_decoder(Some(ByteLevel::new(add_prefix_space, trim_offsets, use_regex)))
         .build()?;

     let pretty = false;
     tokenizer
         .train_from_pretokenized_data(
+            // .train_from_files(
             &mut trainer,
             vec!["/home/felipe_cohere_com/pretokenized.tsv".to_string()],
         )?
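The example change above drops the Strip/NFC normalizers, builds the ByteLevel pre-tokenizer, post-processor and decoder with explicit `add_prefix_space` / `trim_offsets` / `use_regex` flags (all disabled), and trains from a pre-tokenized TSV file rather than from raw corpus files. Judging from the `feed` changes below, each line of that file is expected to hold a text segment and its frequency separated by a single tab; the file itself is not part of this patch, so the sample below is only an assumed illustration of the format:

    low	5
    lower	2
    widest	3
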
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index b53070ac..8fb81862 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -683,16 +683,19 @@ impl Trainer for BpeTrainer {
             .maybe_par_bridge()
             .map(|sequence| {
                 let split_seq = process(sequence.as_ref())?;
+                let mut count_str = split_seq.last().expect("This should work").to_string();
                 let segment = split_seq[0]
                     .replace("", "\n")
                     .replace("", "\t")
                     .replace("", "\r")
                     .clone();
-                let mut count_str = split_seq[1].to_string();
+
                 let count_int: u64 = count_str.trim_end().parse().unwrap();
                 let mut map = HashMap::new();
-                map.insert(segment, count_int);
+                for i in 0..split_seq.len() - 1 {
+                    map.insert(split_seq[i].clone(), count_int);
+                }
                 Ok(map)
             })
             .reduce(
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index d83b507a..065edffb 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -666,14 +666,19 @@ impl Trainer for UnigramTrainer {
             .maybe_par_bridge()
             .map(|sequence| {
                 let split_seq = process(sequence.as_ref())?;
+                let mut count_str = split_seq.last().expect("This should work").to_string();
                 let segment = split_seq[0]
                     .replace("", "\n")
                     .replace("", "\t")
                     .replace("", "\r")
-                    .clone(); let mut count_str = split_seq[1].to_string();
+                    .clone();
+
+                let count_int: u32 = count_str.trim_end().parse().unwrap();
                 let mut map = HashMap::new();
-                map.insert(segment, count_int);
+                for i in 0..split_seq.len() - 1 {
+                    map.insert(split_seq[i].clone(), count_int);
+                }
                 Ok(map)
             })
             .reduce(
diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs
index 21ee1edf..52ed0f2d 100644
--- a/tokenizers/src/models/wordlevel/trainer.rs
+++ b/tokenizers/src/models/wordlevel/trainer.rs
@@ -135,14 +135,19 @@ impl Trainer for WordLevelTrainer {
             .maybe_par_bridge()
             .map(|sequence| {
                 let split_seq = process(sequence.as_ref())?;
+                let mut count_str = split_seq.last().expect("This should work").to_string();
                 let segment = split_seq[0]
                     .replace("", "\n")
                     .replace("", "\t")
                     .replace("", "\r")
-                    .clone(); let mut count_str = split_seq[1].to_string();
+                    .clone();
+
+                let count_int: u64 = count_str.trim_end().parse().unwrap();
                 let mut map = HashMap::new();
-                map.insert(segment, count_int);
+                for i in 0..split_seq.len() - 1 {
+                    map.insert(split_seq[i].clone(), count_int);
+                }
                 Ok(map)
             })
             .reduce(
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 61ea10fa..a9041422 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -1198,7 +1198,15 @@ where
             }),
             |seq| {
                 let split_seq: Vec<String> = seq.split('\t').map(|s| s.to_string()).collect();
-                Ok(split_seq)
+                let normalized = self.do_normalize(split_seq[0].clone())?;
+                let pre_tokenized = self.do_pre_tokenize(normalized)?;
+                let mut pretokenized_and_freq: Vec<_> = pre_tokenized
+                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                    .into_iter()
+                    .map(|(s, _, _)| s.to_owned())
+                    .collect();
+                pretokenized_and_freq.extend(vec![split_seq[1].clone()].iter().map(|s| s.to_string()));
+                Ok(pretokenized_and_freq)
             },
         )?;
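
Taken together, the changes route each input line through the new closure in `tokenizer/mod.rs`: the line is split on tabs, the text field is normalized and pre-tokenized, and the frequency field is appended, so each trainer's `feed` receives a vector of pre-token strings followed by one count and credits every pre-token with that count. The standalone sketch below mirrors that flow under two assumptions the patch does not show: whitespace splitting stands in for the real pre-tokenizer, and the `reduce` step is assumed to sum counts across lines.

    use std::collections::HashMap;

    // Sketch of the per-line processing implied by the patch, assuming input
    // lines of the form "segment<TAB>count". Whitespace splitting stands in
    // for the real pre-tokenizer; the summing merge mirrors what the (unshown)
    // reduce step is assumed to do.
    fn count_pretokenized_lines(lines: &[&str]) -> HashMap<String, u64> {
        let mut words: HashMap<String, u64> = HashMap::new();
        for line in lines {
            // Split once from the right: the last field is the count, the rest is text.
            let mut fields = line.rsplitn(2, '\t');
            let count: u64 = fields
                .next()
                .and_then(|c| c.trim_end().parse().ok())
                .unwrap_or(0);
            let segment = fields.next().unwrap_or("");

            // Per-line map: every pre-token of the segment gets this line's count,
            // matching the `map.insert(split_seq[i].clone(), count_int)` loop above.
            let mut per_line: HashMap<String, u64> = HashMap::new();
            for pre_token in segment.split_whitespace() {
                per_line.insert(pre_token.to_string(), count);
            }

            // Merge per-line maps by summing counts (assumed reduce behaviour).
            for (token, c) in per_line {
                *words.entry(token).or_insert(0) += c;
            }
        }
        words
    }

    fn main() {
        let lines = ["low lower\t5", "low\t2"];
        let words = count_pretokenized_lines(&lines);
        assert_eq!(words["low"], 7);
        assert_eq!(words["lower"], 5);
        println!("{words:?}");
    }

Note that, like the patch, the sketch inserts (rather than adds) the count for duplicate pre-tokens within a single line; only the cross-line merge accumulates.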