diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 217c37e90..4d8ffbd3a 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -356,8 +356,12 @@ impl BPE { /// Resize the cache pub fn resize_cache(&mut self, capacity: usize) { - if let Some(ref mut cache) = self.cache { + if capacity == 0 { + self.cache = None; + } else if let Some(cache) = self.cache.as_mut() { cache.resize(capacity); + } else { + self.cache = Some(Cache::new(capacity)); } } diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index da4d631ce..261ffff2c 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -18,7 +18,7 @@ type Vocab = Vec<(String, f64)>; pub struct Unigram { token_to_ids: TokenMap, pub(crate) vocab: Vocab, - cache: Cache<String, Vec<String>>, + cache: Option<Cache<String, Vec<String>>>, trie: Trie<u8>, pub min_score: f64, pub(super) unk_id: Option<usize>, @@ -39,7 +39,7 @@ impl Clone for Unigram { // `Clone` can't be derive because it's not implemented for `Cache`. // To keep things simple when we clone, the new Unigram will start with a fresh cache. 
fn clone(&self) -> Self { - let fresh_cache = self.cache.fresh(); + let fresh_cache = self.cache.as_ref().map(|cache| cache.fresh()); Self { vocab: self.vocab.clone(), cache: fresh_cache, @@ -134,7 +134,7 @@ impl Unigram { eos_id, unk_id, fuse_unk, - cache: Cache::default(), + cache: Some(Cache::default()), is_optimized, byte_fallback, }) @@ -143,7 +143,7 @@ impl Unigram { #[cfg(test)] pub(super) fn set_fuse_unk(&mut self, fuse_unk: bool) { self.fuse_unk = fuse_unk; - self.cache = self.cache.fresh(); + self.cache = self.cache.as_ref().map(|cache| cache.fresh()); } #[cfg(test)] @@ -222,19 +222,20 @@ impl Unigram { if sentence.is_empty() { return Ok(vec![]); } - if let Some(result) = self.cache.get(sentence) { - Ok(result.to_vec()) + if let Some(result) = self.cache.as_ref().and_then(|cache| cache.get(sentence)) { + return Ok(result.to_vec()); + } + let result = if self.is_optimized { + self.encode_optimized(sentence)? } else { - let result = if self.is_optimized { - self.encode_optimized(sentence)? - } else { - self.encode_unoptimized(sentence)? - }; - if sentence.len() < MAX_LENGTH { - self.cache.set(sentence.to_owned(), result.clone()); + self.encode_unoptimized(sentence)? 
+ }; + if sentence.len() < MAX_LENGTH { + if let Some(cache) = self.cache.as_ref() { + cache.set(sentence.to_owned(), result.clone()); } - Ok(result) } + Ok(result) } fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> { @@ -382,12 +383,20 @@ impl Unigram { /// Clears the internal cache pub fn clear_cache(&mut self) { - self.cache.clear(); + if let Some(cache) = self.cache.as_ref() { + cache.clear(); + } } /// Resize the cache pub fn resize_cache(&mut self, capacity: usize) { - self.cache.resize(capacity); + if capacity == 0 { + self.cache = None; + } else if let Some(cache) = self.cache.as_mut() { + cache.resize(capacity); + } else { + self.cache = Some(Cache::new(capacity)); + } } } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index 808d120d5..f866c4904 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -602,6 +602,11 @@ where &self.model } + /// Get a mutable reference to the model + pub fn get_model_mut(&mut self) -> &mut M { + &mut self.model + } + /// Set the added vocabulary. pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self { self.added_vocabulary = added_vocabulary;