fix: chunking crash on chinese characters
louis030195 committed Aug 12, 2024
1 parent 8c6bbe0 commit e653aa3
Showing 2 changed files with 45 additions and 23 deletions.
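
Note on the underlying bug (inferred from the removed code below, not stated in the commit message itself): the old loop advanced `start` by raw byte offsets (`start = end - overlap`), and slicing a `&str` at a byte index that falls inside a multi-byte codepoint panics in Rust. ASCII text never triggers this, but Chinese characters occupy 3 bytes each in UTF-8. A minimal repro of that failure class, illustrative only:

```rust
fn main() {
    let text = "謝謝大家"; // each of these chars is 3 bytes in UTF-8
    // Byte-offset slicing panics when the index lands inside a codepoint:
    // let _ = &text[0..2]; // panic: byte index 2 is not a char boundary

    // Indexing by chars, as the fix does, is always boundary-safe:
    let chars: Vec<char> = text.chars().collect();
    let chunk: String = chars[0..2].iter().collect();
    assert_eq!(chunk, "謝謝");
}
```

The fix below replaces the byte-indexed loop with indices into a `Vec<char>`, where every index is a valid boundary.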
44 changes: 21 additions & 23 deletions screenpipe-server/src/chunking.rs
@@ -1,14 +1,13 @@
 use anyhow::Result;
-use candle::{Device, Tensor, DType};
-use candle_nn::{VarBuilder, Module};
+use candle::{DType, Device, Tensor};
+use candle_nn::{Module, VarBuilder};
 use candle_transformers::models::jina_bert::{BertModel, Config};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 pub async fn text_chunking_by_similarity(text: &str) -> Result<Vec<String>> {
-    let device = Device::new_metal(0)
-        .unwrap_or_else(|_| Device::new_cuda(0)
-        .unwrap_or(Device::Cpu));
+    let device =
+        Device::new_metal(0).unwrap_or_else(|_| Device::new_cuda(0).unwrap_or(Device::Cpu));
     let repo = Repo::with_revision(
         "jinaai/jina-embeddings-v2-base-en".to_string(),
         RepoType::Model,
@@ -36,14 +35,17 @@ pub async fn text_chunking_by_similarity(text: &str) -> Result<Vec<String>> {
     let max_chunk_length = 300;
 
     for sentence in sentences {
-        let tokens = tokenizer.encode(sentence, true).map_err(anyhow::Error::msg)?;
+        let tokens = tokenizer
+            .encode(sentence, true)
+            .map_err(anyhow::Error::msg)?;
         let token_ids = Tensor::new(tokens.get_ids(), &device)?;
         let embeddings = model.forward(&token_ids.unsqueeze(0)?)?;
         let sentence_embedding = embeddings.mean(1)?;
 
         let should_split = if let Some(prev_emb) = &previous_embedding {
             let similarity = cosine_similarity(&sentence_embedding, prev_emb)?;
-            similarity < similarity_threshold || current_chunk.len() + sentence.len() > max_chunk_length
+            similarity < similarity_threshold
+                || current_chunk.len() + sentence.len() > max_chunk_length
         } else {
             false
         };
@@ -78,22 +80,17 @@ pub fn text_chunking_simple(text: &str) -> Result<Vec<String>> {
         // Chunk by fixed character count with overlap
         let chunk_size = 200;
         let overlap = 30;
+        let chars: Vec<char> = text.chars().collect();
         let mut start = 0;
 
-        while start < text.len() {
-            let end = (start + chunk_size).min(text.len());
-
-            // Find a valid char boundary
-            let end = text[start..].char_indices()
-                .take_while(|(i, _)| *i + start <= end)
-                .last()
-                .map(|(i, _)| start + i + 1)
-                .unwrap_or(text.len());
-
-            // Safely create the chunk
-            chunks.push(text[start..end].to_string());
-
-            start = if end == text.len() { end } else { end - overlap };
+        while start < chars.len() {
+            let end = (start + chunk_size).min(chars.len());
+            chunks.push(chars[start..end].iter().collect());
+            start = if end == chars.len() {
+                end
+            } else {
+                end - overlap
+            };
         }
     }
 
@@ -106,6 +103,7 @@ fn cosine_similarity(a: &Tensor, b: &Tensor) -> Result<f32> {
     let dot_product = (&a * &b)?.sum_all()?;
     let norm_a = a.sqr()?.sum_all()?.sqrt()?;
     let norm_b = b.sqr()?.sum_all()?.sqrt()?;
-    let similarity = dot_product.to_scalar::<f32>()? / (norm_a.to_scalar::<f32>()? * norm_b.to_scalar::<f32>()?);
+    let similarity = dot_product.to_scalar::<f32>()?
+        / (norm_a.to_scalar::<f32>()? * norm_b.to_scalar::<f32>()?);
     Ok(similarity)
 }
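
As a sketch of the post-fix semantics (the `chunk_chars` free function below is hypothetical; the repo keeps this logic inline in `text_chunking_simple`): chunk size and overlap are now counted in chars rather than bytes, so a 200-unit chunk of CJK text is roughly 600 bytes, and no slice can split a codepoint.

```rust
// Sketch only: the fixed loop extracted as a standalone function.
fn chunk_chars(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    let chars: Vec<char> = text.chars().collect();
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < chars.len() {
        let end = (start + chunk_size).min(chars.len());
        chunks.push(chars[start..end].iter().collect());
        // Step back by `overlap` chars so consecutive chunks share context.
        start = if end == chars.len() { end } else { end - overlap };
    }
    chunks
}

fn main() {
    // 400 chars (1200 bytes) of multi-byte text; byte arithmetic could land
    // `start` mid-codepoint here, but char indices cannot.
    let text = "謝謝大家".repeat(100);
    let chunks = chunk_chars(&text, 200, 30);
    assert_eq!(chunks.len(), 3); // windows [0,200), [170,370), [340,400)
    assert_eq!(chunks[0].chars().count(), 200);
}
```

One caveat visible in the loop: it terminates only because `overlap < chunk_size`; with `overlap >= chunk_size`, `start` would stop advancing on every non-final iteration.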
24 changes: 24 additions & 0 deletions screenpipe-server/tests/chunking_test.rs
@@ -0,0 +1,24 @@
+use anyhow::Result;
+use screenpipe_server::chunking::text_chunking_simple;
+
+#[test]
+fn test_text_chunking_with_chinese_characters() {
+    let chinese_text = "謝謝大家".repeat(100); // Repeat 100 times to ensure we exceed chunk size
+    let result = text_chunking_simple(&chinese_text);
+
+    assert!(
+        result.is_ok(),
+        "Function should not panic with Chinese characters"
+    );
+
+    let chunks = result.unwrap();
+    assert!(!chunks.is_empty(), "Should produce at least one chunk");
+
+    for chunk in chunks {
+        assert!(!chunk.is_empty(), "Each chunk should contain text");
+        assert!(
+            chunk.chars().all(|c| c == '謝' || c == '大' || c == '家'),
+            "Chunks should only contain expected characters"
+        );
+    }
+}
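
To run just this regression test locally, something like `cargo test -p screenpipe-server --test chunking_test` should work, assuming the package name matches the directory and the default Cargo layout, where files under `tests/` build as standalone integration-test targets.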
