diff --git a/src/chonkie/__init__.py b/src/chonkie/__init__.py
index d69ee29..7756d17 100644
--- a/src/chonkie/__init__.py
+++ b/src/chonkie/__init__.py
@@ -1,7 +1,5 @@
 """Main package for Chonkie."""
 
-from .context import Context
-
 from .chunker import (
     BaseChunker,
     Chunk,
@@ -15,6 +13,7 @@
     TokenChunker,
     WordChunker,
 )
+from .context import Context
 from .embeddings import (
     AutoEmbeddings,
     BaseEmbeddings,
@@ -22,7 +21,6 @@
     OpenAIEmbeddings,
     SentenceTransformerEmbeddings,
 )
-
 from .refinery import (
     BaseRefinery,
     OverlapRefinery,
diff --git a/src/chonkie/chunker/base.py b/src/chonkie/chunker/base.py
index af0fe3a..2790f5b 100644
--- a/src/chonkie/chunker/base.py
+++ b/src/chonkie/chunker/base.py
@@ -1,7 +1,6 @@
 """Base classes for chunking text."""
 
 import importlib
-
 import inspect
 import warnings
 from abc import ABC, abstractmethod
@@ -58,7 +57,7 @@ def __repr__(self) -> str:
     def __iter__(self):
         """Return an iterator over the chunk."""
         return iter(self.text)
-    
+
     def __getitem__(self, index: int):
         """Return the item at the given index."""
         return self.text[index]
@@ -73,7 +72,6 @@ def copy(self) -> "Chunk":
         )
 
 
-
 class BaseChunker(ABC):
     """Abstract base class for all chunker implementations.
 
@@ -81,7 +79,9 @@ class BaseChunker(ABC):
     the chunk() method according to their specific chunking strategy.
     """
 
-    def __init__(self, tokenizer_or_token_counter: Union[str, Any, Callable[[str], int]]):
+    def __init__(
+        self, tokenizer_or_token_counter: Union[str, Any, Callable[[str], int]]
+    ):
         """Initialize the chunker with a tokenizer.
 
         Args:
@@ -116,7 +116,6 @@ def _get_tokenizer_backend(self):
                 f"Tokenizer backend {str(type(self.tokenizer))} not supported"
             )
 
-
     def _load_tokenizer(self, tokenizer_name: str):
         """Load a tokenizer based on the backend."""
         try:
@@ -186,7 +185,6 @@ def _encode(self, text: str):
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _encode_batch(self, texts: List[str]):
         """Encode a batch of texts using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
@@ -205,7 +203,6 @@
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _decode(self, tokens) -> str:
         """Decode tokens using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
@@ -219,7 +216,6 @@
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
         """Decode a batch of token lists using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
@@ -233,7 +229,6 @@
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _count_tokens(self, text: str) -> int:
         """Count tokens in text using the backend tokenizer."""
         return self.token_counter(text)
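
Note on the `__init__` reflow above: the widened signature makes the contract of
`tokenizer_or_token_counter` easier to read — a tokenizer name, a tokenizer object,
or a bare `str -> int` callable, with `_count_tokens()` routing everything through
`self.token_counter`. A minimal sketch of the three call styles, assuming
`TokenChunker` forwards its `tokenizer` argument to `BaseChunker` unchanged (the
parameter name and the commented-out backend example are assumptions, not taken
from this patch):

    from chonkie import TokenChunker

    # 1. By name: resolved through _load_tokenizer() against an installed backend.
    chunker = TokenChunker(tokenizer="gpt2")

    # 2. By object: anything whose type _get_tokenizer_backend() recognizes,
    #    e.g. a transformers tokenizer (the only backend named in these hunks).
    # from transformers import AutoTokenizer
    # chunker = TokenChunker(tokenizer=AutoTokenizer.from_pretrained("gpt2"))

    # 3. By callable: used directly as the token counter, no backend dispatch.
    chunker = TokenChunker(tokenizer=lambda text: len(text.split()))
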
diff --git a/src/chonkie/chunker/semantic.py b/src/chonkie/chunker/semantic.py
index 025bea6..f005038 100644
--- a/src/chonkie/chunker/semantic.py
+++ b/src/chonkie/chunker/semantic.py
@@ -41,7 +41,7 @@ class SemanticChunk(SentenceChunk):
         sentences: List of SemanticSentence objects in the chunk
 
     """
-    
+
     sentences: List[SemanticSentence] = field(default_factory=list)
 
 
@@ -128,9 +128,10 @@ def __init__(
 
         # Probably the dependency is not installed
         if self.embedding_model is None:
-            raise ImportError("embedding_model is not a valid embedding model",
-                              "Please install the `semantic` extra to use this feature")
-
+            raise ImportError(
+                "embedding_model is not a valid embedding model",
+                "Please install the `semantic` extra to use this feature",
+            )
 
         # Keeping the tokenizer the same as the sentence model is important
         # for the group semantic meaning to be calculated properly
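
One behavioral note on the reformatted raise: `ImportError` still receives two
separate strings, so `str(e)` renders a tuple rather than one sentence; joining
them into a single message would read more cleanly. A quick standalone check of
what the current form produces (plain Python, nothing chonkie-specific):

    try:
        raise ImportError(
            "embedding_model is not a valid embedding model",
            "Please install the `semantic` extra to use this feature",
        )
    except ImportError as e:
        print(str(e))  # ('embedding_model is not a valid embedding model', '...')

Assuming the `semantic` extra is declared in the package metadata, the fix the
message asks for would be `pip install "chonkie[semantic]"`.
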
diff --git a/src/chonkie/chunker/sentence.py b/src/chonkie/chunker/sentence.py
index 174044a..7028a80 100644
--- a/src/chonkie/chunker/sentence.py
+++ b/src/chonkie/chunker/sentence.py
@@ -1,6 +1,5 @@
 from bisect import bisect_left
 from dataclasses import dataclass, field
-
 from itertools import accumulate
 from typing import Any, List, Union
 
@@ -26,6 +25,7 @@ class Sentence:
     end_index: int
     token_count: int
 
+
 @dataclass
 class SentenceChunk(Chunk):
     """Dataclass representing a sentence chunk with metadata.
@@ -40,9 +40,11 @@ class SentenceChunk(Chunk):
         sentences: List of Sentence objects in the chunk
 
     """
 
+    # Don't redeclare inherited fields
     sentences: List[Sentence] = field(default_factory=list)
 
 
+
 class SentenceChunker(BaseChunker):
     """SentenceChunker splits the sentences in a text based on token limits and sentence boundaries.
@@ -238,18 +240,20 @@ def _get_token_counts(self, sentences: List[str]) -> List[int]:
 
     def _estimate_token_counts(self, text: str) -> int:
         """Estimate token count using character length."""
-        CHARS_PER_TOKEN = 6.  # Avg. char per token for llama3 is b/w 6-7
+        CHARS_PER_TOKEN = 6.0  # Avg. char per token for llama3 is b/w 6-7
         if type(text) is str:
             return max(1, int(len(text) / CHARS_PER_TOKEN))
         elif type(text) is list and type(text[0]) is str:
             return [max(1, int(len(t) / CHARS_PER_TOKEN)) for t in text]
         else:
-            raise ValueError(f"Unknown type passed to _estimate_token_count: {type(text)}")
-
+            raise ValueError(
+                f"Unknown type passed to _estimate_token_count: {type(text)}"
+            )
+
     def _get_feedback(self, estimate: int, actual: int) -> float:
         """Validate against the actual token counts and correct the estimates."""
-        feedback = 1 - ((estimate-actual)/estimate)
-        return feedback
+        feedback = 1 - ((estimate - actual) / estimate)
+        return feedback
 
     def _prepare_sentences(self, text: str) -> List[Sentence]:
         """Split text into sentences and calculate token counts for each sentence.
@@ -261,67 +265,62 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
             List of Sentence objects
 
         """
-        # Split text into sentences 
+        # Split text into sentences
        sentence_texts = self._split_sentences(text)
         if not sentence_texts:
             return []
-        
+
         # Calculate positions once
         positions = []
         current_pos = 0
         for sent in sentence_texts:
             positions.append(current_pos)
             current_pos += len(sent) + 1  # +1 for space/separator
-        
+
         if not self.use_approximate:
             # Get accurate token counts in batch
             token_counts = self._get_token_counts(sentence_texts)
         else:
             # Estimate token counts using character length
             token_counts = self._estimate_token_counts(sentence_texts)
-        
+
         # Create sentence objects
         return [
             Sentence(
-                text=sent,
-                start_index=pos,
-                end_index=pos + len(sent),
-                token_count=count
+                text=sent, start_index=pos, end_index=pos + len(sent), token_count=count
             )
             for sent, pos, count in zip(sentence_texts, positions, token_counts)
         ]
-    
+
     def _prepare_sentences(self, text: str) -> List[Sentence]:
         """Prepare sentences with either estimated or accurate token counts."""
         # Split text into sentences
         sentence_texts = self._split_sentences(text)
         if not sentence_texts:
             return []
-        
+
         # Calculate positions once
         positions = []
         current_pos = 0
         for sent in sentence_texts:
             positions.append(current_pos)
             current_pos += len(sent) + 1  # +1 for space/separator
-        
+
         if not self.use_approximate:
             # Get accurate token counts in batch
             token_counts = self._get_token_counts(sentence_texts)
         else:
             # Estimate token counts using character length
             token_counts = self._estimate_token_counts(sentence_texts)
-        
+
         # Create sentence objects
         return [
             Sentence(
-                text=sent,
-                start_index=pos,
-                end_index=pos + len(sent),
-                token_count=count
+                text=sent, start_index=pos, end_index=pos + len(sent), token_count=count
             )
             for sent, pos, count in zip(sentence_texts, positions, token_counts)
         ]
+
     def _create_chunk(self, sentences: List[Sentence], token_count: int) -> Chunk:
         """Create a chunk from a list of sentences.
@@ -354,41 +353,43 @@ def chunk(self, text: str) -> List[Chunk]:
         """
         if not text.strip():
             return []
-        
+
         # Get prepared sentences with token counts
-        sentences = self._prepare_sentences(text) #28mus
+        sentences = self._prepare_sentences(text)  # 28mus
         if not sentences:
             return []
-        
+
         # Pre-calculate cumulative token counts for bisect
         # Add 1 token for spaces between sentences
-        token_sums = list(accumulate(
-            [s.token_count for s in sentences], lambda a,b: a + b, initial=0
-        ))
-        
+        token_sums = list(
+            accumulate(
+                [s.token_count for s in sentences], lambda a, b: a + b, initial=0
+            )
+        )
+
         chunks = []
-        feedback = 1.
+        feedback = 1.0
         pos = 0
-        
+
         while pos < len(sentences):
             # use updated feedback on the token sums
             token_sums = [int(s * feedback) for s in token_sums]
-            
+
             # Use bisect_left to find initial split point
             target_tokens = token_sums[pos] + self.chunk_size
-            split_idx = bisect_left(token_sums, target_tokens) -1
+            split_idx = bisect_left(token_sums, target_tokens) - 1
             split_idx = min(split_idx, len(sentences))
-            
+
             # Ensure we include at least one sentence beyond pos
             split_idx = max(split_idx, pos + 1)
-            
+
             # Handle minimum sentences requirement
             if split_idx - pos < self.min_sentences_per_chunk:
                 split_idx = pos + self.min_sentences_per_chunk
 
             # Get the estimated token count
-            estimate = token_sums[split_idx] - token_sums[pos] 
-            
+            estimate = token_sums[split_idx] - token_sums[pos]
+
             # Get candidate sentences and verify actual token count
             chunk_sentences = sentences[pos:split_idx]
             chunk_text = " ".join(s.text for s in chunk_sentences)
@@ -397,23 +398,25 @@ def chunk(self, text: str) -> List[Chunk]:
             # Given the actual token_count and the estimate, get a feedback value for the next loop
             feedback = self._get_feedback(estimate, actual)
             # print(f"Estimate: {estimate} Actual: {actual} feedback: {feedback}")
-            
+
             # Back off one sentence at a time if we exceeded chunk size
-            while (actual > self.chunk_size and
-                   len(chunk_sentences) > self.min_sentences_per_chunk):
+            while (
+                actual > self.chunk_size
+                and len(chunk_sentences) > self.min_sentences_per_chunk
+            ):
                 split_idx -= 1
                 chunk_sentences = sentences[pos:split_idx]
                 chunk_text = " ".join(s.text for s in chunk_sentences)
                 actual = len(self._encode(chunk_text))
-            
+
             chunks.append(self._create_chunk(chunk_sentences, actual))
-            
+
             # Calculate next position with overlap
             if self.chunk_overlap > 0 and split_idx < len(sentences):
                 # Calculate how many sentences we need for overlap
                 overlap_tokens = 0
                 overlap_idx = split_idx - 1
-                
+
                 while overlap_idx > pos and overlap_tokens < self.chunk_overlap:
                     sent = sentences[overlap_idx]
                     next_tokens = overlap_tokens + sent.token_count + 1  # +1 for space
@@ -421,12 +424,12 @@ def chunk(self, text: str) -> List[Chunk]:
                     break
                 overlap_tokens = next_tokens
                 overlap_idx -= 1
-            
+
             # Move position to after the overlap
             pos = overlap_idx + 1
         else:
             pos = split_idx
-    
+
         return chunks
 
     def __repr__(self) -> str:
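
The arithmetic being reformatted here is worth spelling out: `_estimate_token_counts()`
guesses roughly six characters per token, and `chunk()` multiplies its cumulative
`token_sums` by `feedback = 1 - (estimate - actual) / estimate` (algebraically
`actual / estimate`) so the next `bisect_left` search runs against counts scaled
toward what the tokenizer actually reported. A standalone sketch of one correction
step — only the two formulas come from this patch, the rest is illustrative:

    CHARS_PER_TOKEN = 6.0  # same heuristic as _estimate_token_counts()

    def estimate_tokens(text: str) -> int:
        return max(1, int(len(text) / CHARS_PER_TOKEN))

    def get_feedback(estimate: int, actual: int) -> float:
        return 1 - ((estimate - actual) / estimate)

    text = "Chonkie splits long documents into token-bounded chunks."
    estimate = estimate_tokens(text)           # 9 for this 56-character string
    actual = 12                                # pretend the real tokenizer said 12
    feedback = get_feedback(estimate, actual)  # ~1.33, i.e. actual / estimate
    corrected = int(estimate * feedback)       # scaled back to roughly 12
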
diff --git a/src/chonkie/context.py b/src/chonkie/context.py
index d532b72..ba5ebb6 100644
--- a/src/chonkie/context.py
+++ b/src/chonkie/context.py
@@ -1,11 +1,11 @@
 """Context class for storing contextual information for chunk refinement.
 
 This class is used to store contextual information for chunk refinement.
-It can represent context that comes before a chunk at the moment. 
+It can represent context that comes before a chunk at the moment.
 
 By default, the context has no start and end indices, meaning it is not
 bound to any specific text. The start and end indices are only set if the
-context is part of the same text as the chunk. 
+context is part of the same text as the chunk.
 """
 
 from dataclasses import dataclass
@@ -15,17 +15,17 @@
 @dataclass
 class Context:
     """A dataclass representing contextual information for chunk refinement.
-    
+
     This class stores text and token count information that can be used
     to add context to chunks during the refinement process. It can represent
     context that comes before or after a chunk.
-    
+
     Attributes:
         text (str): The context text
         token_count (int): Number of tokens in the context text
         start_index (Optional[int]): Starting position of context in original text
         end_index (Optional[int]): Ending position of context in original text
-    
+
     Example:
         context = Context(
             text="This is some context.",
@@ -40,7 +40,7 @@ class Context:
     token_count: int
     start_index: Optional[int] = None
     end_index: Optional[int] = None
-    
+
     def __post_init__(self):
         """Validate the Context attributes after initialization."""
         if not isinstance(self.text, str):
@@ -49,19 +49,24 @@ def __post_init__(self):
             raise ValueError("token_count must be an integer")
         if self.token_count < 0:
             raise ValueError("token_count must be non-negative")
-        if (self.start_index is not None and self.end_index is not None and
-            self.start_index > self.end_index):
+        if (
+            self.start_index is not None
+            and self.end_index is not None
+            and self.start_index > self.end_index
+        ):
             raise ValueError("start_index must be less than or equal to end_index")
-    
+
     def __len__(self) -> int:
         """Return the length of the context text."""
         return len(self.text)
-    
+
     def __str__(self) -> str:
         """Return a string representation of the Context."""
         return self.text
-    
+
     def __repr__(self) -> str:
         """Return a detailed string representation of the Context."""
-        return (f"Context(text='{self.text}', token_count={self.token_count}, "
-                f"start_index={self.start_index}, end_index={self.end_index})")
+        return (
+            f"Context(text='{self.text}', token_count={self.token_count}, "
+            f"start_index={self.start_index}, end_index={self.end_index})"
+        )
diff --git a/src/chonkie/refinery/overlap.py b/src/chonkie/refinery/overlap.py
index 94583e3..b9c7128 100644
--- a/src/chonkie/refinery/overlap.py
+++ b/src/chonkie/refinery/overlap.py
@@ -1,4 +1,5 @@
 """Refinery class which adds overlap as context to chunks."""
+
 from typing import Any, List, Optional
 
 from chonkie.chunker import Chunk, SemanticChunk, SentenceChunk
diff --git a/tests/chunker/test_sdpm_chunker.py b/tests/chunker/test_sdpm_chunker.py
index 1ef0b7a..8ad4d98 100644
--- a/tests/chunker/test_sdpm_chunker.py
+++ b/tests/chunker/test_sdpm_chunker.py
@@ -140,6 +140,7 @@ def test_spdm_chunker_repr(embedding_model):
     )
     assert repr(chunker) == expected
 
+
 def test_spdm_chunker_percentile_mode(embedding_model, sample_complex_markdown_text):
     """Test the SPDMChunker works with percentile-based similarity."""
     chunker = SDPMChunker(
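
Since the whitespace-only fixes in context.py leave Context's behavior untouched,
it may help to restate the contract `__post_init__` enforces. A minimal sketch,
using only names visible in that file (the concrete values are illustrative):

    from chonkie.context import Context

    ctx = Context(
        text="This is some context.",
        token_count=5,
        start_index=0,
        end_index=21,
    )
    assert len(ctx) == len(ctx.text)  # __len__ delegates to the text
    assert str(ctx) == ctx.text       # __str__ returns the raw text

    # Reversed indices violate start_index <= end_index and raise ValueError.
    try:
        Context(text="oops", token_count=1, start_index=10, end_index=2)
    except ValueError as e:
        print(e)  # start_index must be less than or equal to end_index
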
diff --git a/tests/refinery/test_overlap_refinery.py b/tests/refinery/test_overlap_refinery.py
index b0371cd..f06092e 100644
--- a/tests/refinery/test_overlap_refinery.py
+++ b/tests/refinery/test_overlap_refinery.py
@@ -1,17 +1,19 @@
-import pytest
 from typing import List
-from dataclasses import dataclass
+
+import pytest
 from transformers import AutoTokenizer
 
-from chonkie.chunker import Chunk, SentenceChunk, SemanticChunk, Sentence
-from chonkie.refinery import OverlapRefinery
+from chonkie.chunker import Chunk, Sentence, SentenceChunk
 from chonkie.context import Context
+from chonkie.refinery import OverlapRefinery
+
 
 @pytest.fixture
 def tokenizer():
     """Fixture providing a GPT-2 tokenizer for testing."""
     return AutoTokenizer.from_pretrained("gpt2")
 
+
 @pytest.fixture
 def basic_chunks() -> List[Chunk]:
     """Fixture providing a list of basic Chunks for testing."""
@@ -20,32 +22,33 @@ def basic_chunks() -> List[Chunk]:
         Chunk(
             text="This is the first chunk of text.",
             start_index=0,
             end_index=30,
-            token_count=8
+            token_count=8,
         ),
         Chunk(
             text="This is the second chunk of text.",
             start_index=31,
             end_index=62,
-            token_count=8
+            token_count=8,
         ),
         Chunk(
             text="This is the third chunk of text.",
             start_index=63,
             end_index=93,
-            token_count=8
-        )
+            token_count=8,
+        ),
     ]
 
+
 @pytest.fixture
 def sentence_chunks() -> List[SentenceChunk]:
     """Fixture providing a list of SentenceChunks for testing."""
     sentences1 = [
         Sentence(text="First sentence.", start_index=0, end_index=14, token_count=3),
-        Sentence(text="Second sentence.", start_index=15, end_index=30, token_count=3)
+        Sentence(text="Second sentence.", start_index=15, end_index=30, token_count=3),
     ]
     sentences2 = [
         Sentence(text="Third sentence.", start_index=31, end_index=45, token_count=3),
-        Sentence(text="Fourth sentence.", start_index=46, end_index=62, token_count=3)
+        Sentence(text="Fourth sentence.", start_index=46, end_index=62, token_count=3),
     ]
     return [
         SentenceChunk(
@@ -53,17 +56,18 @@ def sentence_chunks() -> List[SentenceChunk]:
             start_index=0,
             end_index=30,
             token_count=6,
-            sentences=sentences1
+            sentences=sentences1,
         ),
         SentenceChunk(
             text="Third sentence. Fourth sentence.",
             start_index=31,
             end_index=62,
             token_count=6,
-            sentences=sentences2
-        )
+            sentences=sentences2,
+        ),
     ]
 
+
 def test_overlap_refinery_initialization():
     """Test that OverlapRefinery initializes correctly with different parameters."""
     # Test default initialization
@@ -71,27 +75,26 @@ def test_overlap_refinery_initialization():
     refinery = OverlapRefinery()
     assert refinery.context_size == 128
     assert refinery.merge_context is True
     assert refinery.approximate is True
-    assert not hasattr(refinery, 'tokenizer')
+    assert not hasattr(refinery, "tokenizer")
 
     # Test initialization with tokenizer
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
     refinery = OverlapRefinery(
-        context_size=64,
-        tokenizer=tokenizer,
-        merge_context=False,
-        approximate=False
+        context_size=64, tokenizer=tokenizer, merge_context=False, approximate=False
     )
     assert refinery.context_size == 64
     assert refinery.merge_context is False
     assert refinery.approximate is False
-    assert hasattr(refinery, 'tokenizer')
+    assert hasattr(refinery, "tokenizer")
     assert refinery.tokenizer == tokenizer
 
+
 def test_overlap_refinery_empty_input():
     """Test that OverlapRefinery handles empty input correctly."""
     refinery = OverlapRefinery()
     assert refinery.refine([]) == []
 
+
 def test_overlap_refinery_single_chunk():
     """Test that OverlapRefinery handles single chunk input correctly."""
     refinery = OverlapRefinery()
@@ -100,29 +103,27 @@ def test_overlap_refinery_single_chunk():
     assert len(refined) == 1
     assert refined[0].context is None
 
+
 def test_overlap_refinery_basic_chunks_approximate(basic_chunks):
     """Test approximate overlap calculation with basic Chunks."""
     refinery = OverlapRefinery(context_size=4)  # Small context for testing
     refined = refinery.refine(basic_chunks)
-    
+
     # First chunk should have no context
     assert refined[0].context is None
-    
+
     # Subsequent chunks should have context from previous chunks
     for i in range(1, len(refined)):
         assert refined[i].context is not None
         assert isinstance(refined[i].context, Context)
         assert refined[i].context.token_count <= 4
 
+
 def test_overlap_refinery_basic_chunks_exact(basic_chunks, tokenizer):
     """Test exact overlap calculation with basic Chunks using tokenizer."""
-    refinery = OverlapRefinery(
-        context_size=4,
-        tokenizer=tokenizer,
-        approximate=False
-    )
+    refinery = OverlapRefinery(context_size=4, tokenizer=tokenizer, approximate=False)
     refined = refinery.refine(basic_chunks)
-    
+
     # Check context for subsequent chunks
     for i in range(1, len(refined)):
         assert refined[i].context is not None
         assert isinstance(refined[i].context, Context)
@@ -131,60 +132,62 @@ def test_overlap_refinery_basic_chunks_exact(basic_chunks, tokenizer):
         actual_tokens = len(tokenizer.encode(refined[i].context.text))
         assert actual_tokens <= 4
 
+
 def test_overlap_refinery_sentence_chunks(sentence_chunks):
     """Test overlap calculation with SentenceChunks."""
     refinery = OverlapRefinery(context_size=4)
     refined = refinery.refine(sentence_chunks)
-    
+
     # Check context for second chunk
     assert refined[1].context is not None
     assert isinstance(refined[1].context, Context)
     assert refined[1].context.token_count <= 4
 
+
 def test_overlap_refinery_no_merge_context(basic_chunks):
     """Test behavior when merge_context is False."""
     refinery = OverlapRefinery(context_size=4, merge_context=False)
     refined = refinery.refine(basic_chunks)
-    
+
     # Chunks should maintain original text
     for i in range(len(refined)):
         assert refined[i].text == basic_chunks[i].text
         assert refined[i].token_count == basic_chunks[i].token_count
 
+
 def test_overlap_refinery_context_size_limits(basic_chunks):
     """Test that context size limits are respected."""
     refinery = OverlapRefinery(context_size=2)  # Very small context
     refined = refinery.refine(basic_chunks)
-    
+
     # Check that no context exceeds size limit
     for chunk in refined[1:]:  # Skip first chunk
         assert chunk.context.token_count <= 2
 
+
 def test_overlap_refinery_merge_context(basic_chunks, tokenizer):
     """Test merging context into chunk text."""
     refinery = OverlapRefinery(
-        context_size=4,
-        tokenizer=tokenizer,
-        merge_context=True,
-        approximate=False
+        context_size=4, tokenizer=tokenizer, merge_context=True, approximate=False
     )
-    
+
     # Create a deep copy to preserve originals
     chunks_copy = [
         Chunk(
             text=chunk.text,
             start_index=chunk.start_index,
             end_index=chunk.end_index,
-            token_count=chunk.token_count
-        ) for chunk in basic_chunks
+            token_count=chunk.token_count,
+        )
+        for chunk in basic_chunks
     ]
-    
+
    refined = refinery.refine(chunks_copy)
-    
+
     # First chunk should be unchanged
     assert refined[0].text == basic_chunks[0].text
     assert refined[0].token_count == basic_chunks[0].token_count
-    
+
     # Subsequent chunks should have context prepended
     for i in range(1, len(refined)):
         assert refined[i].context is not None
@@ -194,28 +197,25 @@ def test_overlap_refinery_merge_context(basic_chunks, tokenizer):
         new_tokens = len(tokenizer.encode(refined[i].text))
         assert new_tokens > original_tokens
 
+
 def test_overlap_refinery_mixed_chunk_types():
     """Test that refinery raises error for mixed chunk types."""
     # Create chunks of different types
     chunks = [
-        Chunk(
-            text="Basic chunk.",
-            start_index=0,
-            end_index=12,
-            token_count=3
-        ),
+        Chunk(text="Basic chunk.", start_index=0, end_index=12, token_count=3),
         SentenceChunk(
             text="Sentence chunk.",
             start_index=13,
             end_index=27,
             token_count=3,
-            sentences=[]
-        )
+            sentences=[],
+        ),
     ]
-    
+
     refinery = OverlapRefinery()
     with pytest.raises(ValueError, match="All chunks must be of the same type"):
         refinery.refine(chunks)
 
+
 if __name__ == "__main__":
-    pytest.main()
\ No newline at end of file
+    pytest.main()
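
Taken together, these tests pin down the end-to-end flow the touched modules
implement. A closing sketch of that flow — the constructor parameters for
`SentenceChunker` are inferred from the attributes this patch touches
(`chunk_size`, `chunk_overlap`, `min_sentences_per_chunk`), so treat the exact
signature as an assumption:

    from transformers import AutoTokenizer
    from chonkie.chunker import SentenceChunker
    from chonkie.refinery import OverlapRefinery

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    chunker = SentenceChunker(tokenizer, chunk_size=64, chunk_overlap=0)
    chunks = chunker.chunk("First sentence. Second sentence. Third sentence.")

    refinery = OverlapRefinery(context_size=4, tokenizer=tokenizer, approximate=False)
    refined = refinery.refine(chunks)
    # refined[0].context is None; each later chunk carries a Context of <= 4
    # tokens taken from the tail of the previous chunk.
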