Enhance SemanticChunker with error handling and similarity threshold updates

- Added error handling for missing embedding model, prompting installation of the `semantic` extra (see the sketch just below this list).
- Updated similarity threshold assignment to use the instance variable consistently.
- Introduced a new test for SDPMChunker to validate functionality with percentile-based similarity, ensuring proper chunking behavior and attributes.
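
For context, a rough illustration of the failure mode the first bullet guards against. This is a sketch only: the import path, the placeholder model name, and the install command are assumptions for illustration, not taken from this commit.

from chonkie import SemanticChunker

# If the optional embedding dependencies are missing, the model name cannot be
# resolved and construction now fails fast with an ImportError that points at
# the `semantic` extra, e.g. pip install "chonkie[semantic]" (assumed package name).
try:
    chunker = SemanticChunker(embedding_model="some-embedding-model")  # placeholder name
except ImportError as error:
    print(error)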
bhavnicksm committed Dec 4, 2024
1 parent fb37573 commit 719e33b
Showing 3 changed files with 47 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/chonkie/chunker/semantic.py
@@ -124,6 +124,11 @@ def __init__(
raise ValueError(
    "embedding_model must be a string or BaseEmbeddings instance"
)

+ # The embedding model could not be resolved here, most likely because the
+ # optional dependency is not installed
+ if self.embedding_model is None:
+     raise ImportError(
+         "embedding_model is not a valid embedding model. "
+         "Please install the `semantic` extra to use this feature."
+     )

# Keeping the tokenizer the same as the sentence model is important
# for the group semantic meaning to be calculated properly
@@ -264,11 +269,11 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
)
for i in range(len(sentences) - 1)
]
- similarity_threshold = float(
+ self.similarity_threshold = float(
np.percentile(all_similarities, self.similarity_percentile)
)
else:
- similarity_threshold = self.similarity_threshold
+ self.similarity_threshold = self.similarity_threshold

groups = []
current_group = sentences[: self.initial_sentences]
@@ -280,7 +285,7 @@ def _group_sentences(self, sentences: List[Sentence]) -> List[List[Sentence]]:
current_embedding, sentence.embedding
)

- if similarity >= similarity_threshold:
+ if similarity >= self.similarity_threshold:
# Add to current group
current_group.append(sentence)
# Update mean embedding
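
As a standalone sketch of what the percentile branch in the hunk above computes. This is illustrative only: the cosine-similarity helper below stands in for the chunker's internal similarity function and is not part of the commit.

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Stand-in for the chunker's own sentence-similarity helper.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Pairwise similarities between consecutive sentence embeddings.
embeddings = [np.random.rand(8) for _ in range(6)]
all_similarities = [
    cosine_similarity(embeddings[i], embeddings[i + 1])
    for i in range(len(embeddings) - 1)
]

# With similarity_percentile=50 the threshold becomes the median similarity,
# so adjacent pairs that fall below it will tend to start a new group.
similarity_threshold = float(np.percentile(all_similarities, 50))
print(similarity_threshold)

Storing the result on self, as the hunk above now does, also makes the computed threshold inspectable on the chunker instance after grouping.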
21 changes: 21 additions & 0 deletions src/chonkie/token_factory.py
@@ -0,0 +1,21 @@
"""Factory class for creating and managing tokenizers.
This factory class is used to create and manage tokenizers for the Chonkie
package. It provides a simple interface for initializing, encoding, decoding,
and counting tokens using different tokenizer backends.
This is used in the Chunker and Refinery classes to ensure consistent tokenization
across different parts of the pipeline.
"""

from typing import Callable, List, TYPE_CHECKING


if TYPE_CHECKING:
    import tiktoken
    from transformers import AutoTokenizer
    from tokenizers import Tokenizer

class TokenFactory:
    """Factory class for creating and managing tokenizers."""
    pass
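
The class body is still a stub in this commit. Purely as a sketch of the interface the module docstring describes, it might eventually look something like the following; all method names and the constructor signature are assumptions, not part of the commit.

from typing import List

class TokenFactory:
    """Factory class for creating and managing tokenizers."""

    def __init__(self, tokenizer) -> None:
        # Hypothetical: wrap any backend object that exposes encode/decode,
        # e.g. a tiktoken encoding already satisfies this shape.
        self.tokenizer = tokenizer

    def encode(self, text: str) -> List[int]:
        return self.tokenizer.encode(text)

    def decode(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def count_tokens(self, text: str) -> int:
        return len(self.encode(text))

# Hypothetical usage: TokenFactory(tiktoken.get_encoding("gpt2")).count_tokens("hello")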
18 changes: 18 additions & 0 deletions tests/chunker/test_sdpm_chunker.py
@@ -140,6 +140,24 @@ def test_spdm_chunker_repr(embedding_model):
    )
    assert repr(chunker) == expected

def test_spdm_chunker_percentile_mode(embedding_model, sample_complex_markdown_text):
    """Test that the SDPMChunker works with percentile-based similarity."""
    chunker = SDPMChunker(
        embedding_model=embedding_model,
        chunk_size=512,
        similarity_percentile=50,
    )
    chunks = chunker.chunk(sample_complex_markdown_text)

    assert len(chunks) > 0
    assert isinstance(chunks[0], SemanticChunk)
    assert all([chunk.token_count <= 512 for chunk in chunks])
    assert all([chunk.token_count > 0 for chunk in chunks])
    assert all([chunk.text is not None for chunk in chunks])
    assert all([chunk.start_index is not None for chunk in chunks])
    assert all([chunk.end_index is not None for chunk in chunks])
    assert all([chunk.sentences is not None for chunk in chunks])


if __name__ == "__main__":
    pytest.main()
