diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs
index 1a78586e2..a2da414c0 100644
--- a/tokenizers/src/decoders/wordpiece.rs
+++ b/tokenizers/src/decoders/wordpiece.rs
@@ -45,23 +45,19 @@ pub fn cleanup(dirty_input: &str) -> String {
 
 impl Decoder for WordPiece {
     fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
-        tokens
-            .iter_mut()
-            .enumerate()
-            .map(|(i, token)| {
-                if i != 0 {
-                    if token.starts_with(&self.prefix) {
-                        *token = token.replacen(&self.prefix, "", 1);
-                    } else {
-                        *token = format!(" {token}");
-                    }
+        for (i, token) in tokens.iter_mut().enumerate() {
+            if i != 0 {
+                if let Some(tk) = token.strip_prefix(&self.prefix) {
+                    *token = tk.to_string();
+                } else {
+                    *token = format!(" {token}");
                 }
-                if self.cleanup {
-                    *token = cleanup(token);
-                }
-                Ok(token.to_string())
-            })
-            .collect::<Result<_>>()
+            }
+            if self.cleanup {
+                *token = cleanup(token);
+            }
+        }
+        Ok(tokens)
     }
 }
 
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 217c37e90..05689d407 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -460,7 +460,7 @@ impl BPE {
         Ok(word)
     }
 
-    fn word_to_tokens<'a, 'b: 'a>(&'a self, word: &'b Word) -> impl Iterator<Item = Token> + 'a {
+    fn word_to_tokens<'a>(&'a self, word: &'a Word) -> impl Iterator<Item = Token> + 'a {
         word.get_chars_iter()
             .zip(word.get_offsets_iter())
             .map(move |(id, offsets)| Token::new(id, self.vocab_r[&id].clone(), offsets))
@@ -471,7 +471,7 @@ impl BPE {
         if let Some(id) = self.vocab.get(sequence) {
             return Ok(vec![Token::new(
                 *id,
-                sequence.to_string().clone(),
+                sequence.to_string(),
                 (0, sequence.len()),
             )]);
         }
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index a1a0aba76..955a5a865 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -271,19 +271,13 @@ impl BpeTrainer {
         let mut alphabet: HashMap<char, usize> = HashMap::new();
         for (word, count) in wc {
             for c in word.chars() {
-                alphabet
-                    .entry(c)
-                    .and_modify(|cnt| *cnt += *count as usize)
-                    .or_insert(*count as usize);
+                *alphabet.entry(c).or_default() += *count as usize;
             }
         }
 
         // Also include anything from the provided initial alphabet
         for c in &self.initial_alphabet {
-            alphabet
-                .entry(*c)
-                .and_modify(|cnt| *cnt = usize::MAX)
-                .or_insert(usize::MAX);
+            *alphabet.entry(*c).or_default() = usize::MAX;
         }
 
         let mut kept = alphabet.iter().collect::<Vec<_>>();
@@ -293,13 +287,7 @@
         // will be removed
         let to_remove = self
             .limit_alphabet
-            .map(|limit| {
-                if alphabet.len() > limit {
-                    alphabet.len() - limit
-                } else {
-                    0
-                }
-            })
+            .map(|limit| alphabet.len().saturating_sub(limit))
             .unwrap_or(0);
 
         // Remove the unwanted chars
@@ -309,7 +297,7 @@
         }
 
         // Keep the initial alphabet (sorted for determinism)
-        kept.sort_unstable_by_key(|k| (*k.0) as u32);
+        kept.sort_unstable_by_key(|k| *k.0 as u32);
         kept.into_iter().for_each(|(c, _)| {
             let s = c.to_string();
             if !w2id.contains_key(&s) {
@@ -342,13 +330,13 @@
                     // Add the `continuing_subword_prefix` if relevant
                     if !is_first {
                         if let Some(prefix) = &self.continuing_subword_prefix {
-                            s = format!("{prefix}{s}");
+                            s.insert_str(0, prefix);
                         }
                     }
                     // Add the `end_of_word_suffix` if relevant
                     if is_last {
                         if let Some(suffix) = &self.end_of_word_suffix {
-                            s = format!("{s}{suffix}");
+                            s.push_str(suffix);
                         }
                     }
 
@@ -387,23 +375,9 @@
                     let cur_pair: Pair = (window[0], window[1]);
 
                     // Initialize pair_counts and where_to_update for this pair if we just saw it
-                    if !pair_counts.contains_key(&cur_pair) {
-                        pair_counts.insert(cur_pair, 0);
-                    }
 
-                    // Then update counts
-                    let count = counts[i];
-                    where_to_update
-                        .entry(cur_pair)
-                        .and_modify(|h| {
-                            h.insert(i);
-                        })
-                        .or_insert_with(|| {
-                            let mut h = HashSet::new();
-                            h.insert(i);
-                            h
-                        });
-                    *pair_counts.get_mut(&cur_pair).unwrap() += count as i32;
+                    *pair_counts.entry(cur_pair).or_default() += counts[i] as i32;
+                    where_to_update.entry(cur_pair).or_default().insert(i);
                 }
 
                 if let Some(p) = &p {
@@ -416,13 +390,10 @@
                 || (HashMap::new(), HashMap::new()),
                 |(mut pair_counts, mut where_to_update), (pc, wtu)| {
                     for (k, v) in pc {
-                        pair_counts.entry(k).and_modify(|c| *c += v).or_insert(v);
+                        *pair_counts.entry(k).or_default() += v;
                     }
                     for (k, v) in wtu {
-                        where_to_update
-                            .entry(k)
-                            .and_modify(|set| *set = set.union(&v).copied().collect())
-                            .or_insert(v);
+                        where_to_update.entry(k).or_default().extend(v);
                     }
                     (pair_counts, where_to_update)
                 },
@@ -488,11 +459,10 @@
                 break;
             }
 
-            if queue.is_empty() {
+            let Some(mut top) = queue.pop() else {
                 break;
-            }
+            };
 
-            let mut top = queue.pop().unwrap();
             if top.count != pair_counts[&top.pair] as u64 {
                 top.count = pair_counts[&top.pair] as u64;
                 queue.push(top);
@@ -504,13 +474,12 @@
             }
 
             let part_a = &id_to_word[top.pair.0 as usize];
-            let mut part_b = id_to_word[top.pair.1 as usize].to_owned();
+            let mut part_b = id_to_word[top.pair.1 as usize].as_str();
 
             // Build new token
             if let Some(prefix) = &self.continuing_subword_prefix {
-                if part_b.starts_with(prefix) {
-                    let prefix_byte_len = prefix.chars().map(|c| c.len_utf8()).sum();
-                    part_b = part_b[prefix_byte_len..].to_string();
+                if let Some(rest) = part_b.strip_prefix(prefix) {
+                    part_b = rest;
                 }
             }
             let new_token = format!("{part_a}{part_b}");
@@ -566,21 +535,9 @@
             // Introduce new formed pairs
             for ((pair, change), iw) in changes {
                 let count = change * counts[iw] as i32;
-                pair_counts
-                    .entry(pair)
-                    .and_modify(|c| *c += count)
-                    .or_insert(count);
-
+                *pair_counts.entry(pair).or_default() += count;
                 if change > 0 {
-                    where_to_update
-                        .entry(pair)
-                        .and_modify(|h| {
-                            h.insert(iw);
-                        })
-                        .or_insert_with(|| {
-                            let mut h = HashSet::new();
-                            h.insert(iw);
-                            h
-                        });
+                    where_to_update.entry(pair).or_default().insert(iw);
                 }
             }
             where_to_update.drain().for_each(|(pair, pos)| {
@@ -613,16 +570,8 @@
             .map(|(i, (pair, new_token_id))| (pair, (i as u32, new_token_id)))
             .collect();
 
-        if let Some(prefix) = &self.continuing_subword_prefix {
-            model.continuing_subword_prefix = Some(prefix.to_owned());
-        } else {
-            model.continuing_subword_prefix = None;
-        }
-        if let Some(suffix) = &self.end_of_word_suffix {
-            model.end_of_word_suffix = Some(suffix.to_owned());
-        } else {
-            model.end_of_word_suffix = None;
-        }
+        model.continuing_subword_prefix = self.continuing_subword_prefix.clone();
+        model.end_of_word_suffix = self.end_of_word_suffix.clone();
 
         Ok(self.special_tokens.clone())
     }
@@ -653,7 +602,7 @@ impl Trainer for BpeTrainer {
                 let words = process(sequence.as_ref())?;
                 let mut map = HashMap::new();
                 for word in words {
-                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                    *map.entry(word).or_default() += 1;
                 }
                 Ok(map)
             })
@@ -662,7 +611,7 @@
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
-                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                        *acc.entry(k).or_default() += v;
                     }
                     Ok(acc)
                 },
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index da4d631ce..ae5b46ccc 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -115,8 +115,7 @@ impl Unigram {
         let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocab.iter().enumerate() {
             token_to_ids.insert(token.to_string(), id as u32);
-            let bytes: Vec<u8> = token.bytes().collect();
-            builder.push(&bytes);
+            builder.push(token.as_bytes());
             if score < &min_score {
                 min_score = *score;
             }
         }
@@ -308,22 +307,15 @@
         while ends_at > 0 {
             let node = &best_path_ends_at[ends_at];
             let starts_at = node.starts_at.unwrap();
-            if self.fuse_unk
-                && self.unk_id.is_some()
-                && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)?
-            {
-                token.push(
-                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
-                );
+            if self.fuse_unk && Some(node.id) == self.unk_id {
+                token.push(sentence[starts_at..ends_at].to_string());
             } else {
                 if !token.is_empty() {
                     token.reverse();
                     results.push(token.concat());
                     token = vec![];
                 }
-                results.push(
-                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
-                );
+                results.push(sentence[starts_at..ends_at].to_string());
             }
             ends_at = starts_at;
         }
@@ -350,7 +342,7 @@
                     results.push(token);
                     token = String::new();
                 }
-                results.push(item.to_string());
+                results.push(item);
             }
         }
         if !token.is_empty() {
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5d178e77b..1b005964a 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -428,7 +428,7 @@
             new_pieces.push(pieces[id].clone());
         }
 
-        new_pieces.to_vec()
+        new_pieces
     }
 
     /// Update the progress bar with the new provided length and message
@@ -637,7 +637,7 @@ impl Trainer for UnigramTrainer {
                 let words = process(sequence.as_ref())?;
                 let mut map = HashMap::new();
                 for word in words {
-                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                    *map.entry(word).or_default() += 1;
                 }
                 Ok(map)
             })
@@ -646,7 +646,7 @@
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
-                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                        *acc.entry(k).or_default() += v;
                     }
                     Ok(acc)
                 },
diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs
index c52ad08d7..91a4e03b1 100644
--- a/tokenizers/src/models/wordlevel/trainer.rs
+++ b/tokenizers/src/models/wordlevel/trainer.rs
@@ -106,7 +106,7 @@ impl Trainer for WordLevelTrainer {
                 let words = process(sequence.as_ref())?;
                 let mut map = HashMap::new();
                 for word in words {
-                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
+                    *map.entry(word).or_default() += 1;
                 }
                 Ok(map)
             })
@@ -115,7 +115,7 @@
                 |acc, ws| {
                     let mut acc = acc?;
                     for (k, v) in ws? {
-                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
+                        *acc.entry(k).or_default() += v;
                     }
                     Ok(acc)
                 },
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index f988477be..04288c668 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -296,15 +296,12 @@
                 )
             };
 
             // Make sure we modify the previous entry
-            self.added_tokens_map
+            *self
+                .added_tokens_map
                 .entry(token.content.clone())
-                .and_modify(|old_id| *old_id = new_id)
-                .or_insert_with(|| new_id);
+                .or_default() = new_id;
             // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(new_id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
+            *self.added_tokens_map_r.entry(new_id).or_default() = token.clone();
             // Make sure to remove previous entry (if the token gets a new id)
             // Finally add the token to the classic set if special