From e653aa33848d1e509e28864b0277170e6cbdb4de Mon Sep 17 00:00:00 2001
From: Louis Beaumont
Date: Mon, 12 Aug 2024 10:03:21 +0200
Subject: [PATCH] fix: chunking crash on chinese characters

---
 screenpipe-server/src/chunking.rs        | 44 +++++++++++-------------
 screenpipe-server/tests/chunking_test.rs | 24 +++++++++++++
 2 files changed, 45 insertions(+), 23 deletions(-)
 create mode 100644 screenpipe-server/tests/chunking_test.rs

diff --git a/screenpipe-server/src/chunking.rs b/screenpipe-server/src/chunking.rs
index 649fa304..a98012af 100644
--- a/screenpipe-server/src/chunking.rs
+++ b/screenpipe-server/src/chunking.rs
@@ -1,14 +1,13 @@
 use anyhow::Result;
-use candle::{Device, Tensor, DType};
-use candle_nn::{VarBuilder, Module};
+use candle::{DType, Device, Tensor};
+use candle_nn::{Module, VarBuilder};
 use candle_transformers::models::jina_bert::{BertModel, Config};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 pub async fn text_chunking_by_similarity(text: &str) -> Result<Vec<String>> {
-    let device = Device::new_metal(0)
-        .unwrap_or_else(|_| Device::new_cuda(0)
-        .unwrap_or(Device::Cpu));
+    let device =
+        Device::new_metal(0).unwrap_or_else(|_| Device::new_cuda(0).unwrap_or(Device::Cpu));
     let repo = Repo::with_revision(
         "jinaai/jina-embeddings-v2-base-en".to_string(),
         RepoType::Model,
@@ -36,14 +35,17 @@ pub async fn text_chunking_by_similarity(text: &str) -> Result<Vec<String>> {
     let max_chunk_length = 300;
 
     for sentence in sentences {
-        let tokens = tokenizer.encode(sentence, true).map_err(anyhow::Error::msg)?;
+        let tokens = tokenizer
+            .encode(sentence, true)
+            .map_err(anyhow::Error::msg)?;
         let token_ids = Tensor::new(tokens.get_ids(), &device)?;
         let embeddings = model.forward(&token_ids.unsqueeze(0)?)?;
         let sentence_embedding = embeddings.mean(1)?;
 
         let should_split = if let Some(prev_emb) = &previous_embedding {
             let similarity = cosine_similarity(&sentence_embedding, prev_emb)?;
-            similarity < similarity_threshold || current_chunk.len() + sentence.len() > max_chunk_length
+            similarity < similarity_threshold
+                || current_chunk.len() + sentence.len() > max_chunk_length
         } else {
             false
         };
@@ -78,22 +80,17 @@ pub fn text_chunking_simple(text: &str) -> Result<Vec<String>> {
         // Chunk by fixed character count with overlap
         let chunk_size = 200;
         let overlap = 30;
+        let chars: Vec<char> = text.chars().collect();
 
         let mut start = 0;
-        while start < text.len() {
-            let end = (start + chunk_size).min(text.len());
-
-            // Find a valid char boundary
-            let end = text[start..].char_indices()
-                .take_while(|(i, _)| *i + start <= end)
-                .last()
-                .map(|(i, _)| start + i + 1)
-                .unwrap_or(text.len());
-
-            // Safely create the chunk
-            chunks.push(text[start..end].to_string());
-
-            start = if end == text.len() { end } else { end - overlap };
+        while start < chars.len() {
+            let end = (start + chunk_size).min(chars.len());
+            chunks.push(chars[start..end].iter().collect());
+            start = if end == chars.len() {
+                end
+            } else {
+                end - overlap
+            };
         }
     }
 
@@ -106,6 +103,7 @@ fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
     let dot_product = (&a * &b)?.sum_all()?;
     let norm_a = a.sqr()?.sum_all()?.sqrt()?;
    let norm_b = b.sqr()?.sum_all()?.sqrt()?;
-    let similarity = dot_product.to_scalar::<f32>()? / (norm_a.to_scalar::<f32>()? * norm_b.to_scalar::<f32>()?);
+    let similarity = dot_product.to_scalar::<f32>()?
+        / (norm_a.to_scalar::<f32>()? * norm_b.to_scalar::<f32>()?);
     Ok(similarity)
-}
\ No newline at end of file
+}
diff --git a/screenpipe-server/tests/chunking_test.rs b/screenpipe-server/tests/chunking_test.rs
new file mode 100644
index 00000000..88a1e6f5
--- /dev/null
+++ b/screenpipe-server/tests/chunking_test.rs
@@ -0,0 +1,24 @@
+use anyhow::Result;
+use screenpipe_server::chunking::text_chunking_simple;
+
+#[test]
+fn test_text_chunking_with_chinese_characters() {
+    let chinese_text = "謝謝大家".repeat(100); // Repeat 100 times to ensure we exceed chunk size
+    let result = text_chunking_simple(&chinese_text);
+
+    assert!(
+        result.is_ok(),
+        "Function should not panic with Chinese characters"
+    );
+
+    let chunks = result.unwrap();
+    assert!(!chunks.is_empty(), "Should produce at least one chunk");
+
+    for chunk in chunks {
+        assert!(!chunk.is_empty(), "Each chunk should contain text");
+        assert!(
+            chunk.chars().all(|c| c == '謝' || c == '大' || c == '家'),
+            "Chunks should only contain expected characters"
+        );
+    }
+}
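
Why the old code crashed: `str::len()` and range slicing on `&str` count UTF-8
bytes, not characters, and slicing at an index that is not a char boundary
panics. Each character in "謝謝大家" occupies three bytes, so byte arithmetic
like `start + chunk_size` or `end - overlap` could land mid-character; indexing
into a `Vec<char>`, as the patched `text_chunking_simple` does, makes every
index a valid boundary. A minimal standalone sketch of the failure mode and the
fix (illustration only, not part of the patch):

    fn main() {
        let text = "謝謝大家"; // 4 chars, 3 bytes each in UTF-8
        assert_eq!(text.len(), 12); // len() counts bytes
        assert_eq!(text.chars().count(), 4); // chars() counts characters

        // Byte-indexed slicing panics off a char boundary; this would crash:
        // let _ = &text[0..1]; // panic: byte index 1 is not a char boundary

        // Collecting into Vec<char> makes any index safe to slice at:
        let chars: Vec<char> = text.chars().collect();
        let chunk: String = chars[0..2].iter().collect();
        assert_eq!(chunk, "謝謝");
    }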