Refactor code for improved readability and consistency
- Cleaned up whitespace and formatting across multiple files, including __init__.py, context.py, base.py, semantic.py, sentence.py, and overlap.py.
- Enhanced docstrings for clarity and consistency in the Context and Chunk classes.
- Adjusted method signatures and internal logic for better readability in the BaseChunker and SentenceChunker classes.
- Updated test files to improve formatting and ensure consistent style.
- Ensured consistent use of commas and spacing in function definitions and calls.
bhavnicksm committed Dec 5, 2024
1 parent 3781536 commit 621061f
Showing 8 changed files with 128 additions and 124 deletions.
4 changes: 1 addition & 3 deletions src/chonkie/__init__.py
@@ -1,7 +1,5 @@
 """Main package for Chonkie."""
 
-from .context import Context
-
 from .chunker import (
     BaseChunker,
     Chunk,
@@ -15,14 +13,14 @@
     TokenChunker,
     WordChunker,
 )
+from .context import Context
 from .embeddings import (
     AutoEmbeddings,
     BaseEmbeddings,
     Model2VecEmbeddings,
     OpenAIEmbeddings,
     SentenceTransformerEmbeddings,
 )
-
 from .refinery import (
     BaseRefinery,
     OverlapRefinery,
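Side note on the __init__.py change: the reorder is purely cosmetic, so the flat public API is unchanged. A quick sanity-check sketch (assumes chonkie is installed; nothing here is specific to this commit):

import chonkie
from chonkie import Context, TokenChunker

# Re-exports resolve to the same objects whatever order __init__.py lists them in.
assert chonkie.Context is Context
assert chonkie.TokenChunker is TokenChunker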
13 changes: 4 additions & 9 deletions src/chonkie/chunker/base.py
@@ -1,7 +1,6 @@
 """Base classes for chunking text."""
 
 import importlib
-
 import inspect
 import warnings
 from abc import ABC, abstractmethod
@@ -58,7 +57,7 @@ def __repr__(self) -> str:
     def __iter__(self):
         """Return an iterator over the chunk."""
         return iter(self.text)
-
+
     def __getitem__(self, index: int):
         """Return the item at the given index."""
         return self.text[index]
@@ -73,15 +72,16 @@ def copy(self) -> "Chunk":
         )
 
 
-
 class BaseChunker(ABC):
     """Abstract base class for all chunker implementations.
     All chunker implementations should inherit from this class and implement
     the chunk() method according to their specific chunking strategy.
     """
 
-    def __init__(self, tokenizer_or_token_counter: Union[str, Any, Callable[[str], int]]):
+    def __init__(
+        self, tokenizer_or_token_counter: Union[str, Any, Callable[[str], int]]
+    ):
         """Initialize the chunker with a tokenizer.
         Args:
@@ -116,7 +116,6 @@ def _get_tokenizer_backend(self):
                 f"Tokenizer backend {str(type(self.tokenizer))} not supported"
             )
 
-
     def _load_tokenizer(self, tokenizer_name: str):
         """Load a tokenizer based on the backend."""
         try:
@@ -186,7 +185,6 @@ def _encode(self, text: str):
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _encode_batch(self, texts: List[str]):
         """Encode a batch of texts using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
@@ -205,7 +203,6 @@ def _encode_batch(self, texts: List[str]):
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
    def _decode(self, tokens) -> str:
         """Decode tokens using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
@@ -219,7 +216,6 @@ def _decode(self, tokens) -> str:
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
         """Decode a batch of token lists using the backend tokenizer."""
         if self._tokenizer_backend == "transformers":
@@ -233,7 +229,6 @@ def _decode_batch(self, token_lists: List[List[int]]) -> List[str]:
                 f"Tokenizer backend {self._tokenizer_backend} not supported."
             )
 
-
     def _count_tokens(self, text: str) -> int:
         """Count tokens in text using the backend tokenizer."""
         return self.token_counter(text)
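Side note on the reflowed BaseChunker signature above: it accepts a tokenizer name, a tokenizer object, or a plain Callable[[str], int] token counter, and subclasses implement chunk(). A minimal sketch of a custom chunker; the Chunk constructor arguments are an assumption based on the fields this diff happens to show (text, start_index, end_index, token_count), not something the commit confirms:

from typing import List

from chonkie import BaseChunker, Chunk


class ParagraphChunker(BaseChunker):
    """Hypothetical chunker: one chunk per blank-line-separated paragraph."""

    def chunk(self, text: str) -> List[Chunk]:
        chunks = []
        pos = 0
        for para in text.split("\n\n"):
            chunks.append(
                Chunk(
                    text=para,
                    start_index=pos,
                    end_index=pos + len(para),
                    token_count=self._count_tokens(para),  # backed by token_counter
                )
            )
            pos += len(para) + 2  # skip the "\n\n" separator
        return chunks


# Passing a callable exercises the Callable[[str], int] branch of the signature.
chunker = ParagraphChunker(lambda s: len(s.split()))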
9 changes: 5 additions & 4 deletions src/chonkie/chunker/semantic.py
@@ -41,7 +41,7 @@ class SemanticChunk(SentenceChunk):
         sentences: List of SemanticSentence objects in the chunk
     """
 
     sentences: List[SemanticSentence] = field(default_factory=list)
-
+
 
@@ -128,9 +128,10 @@ def __init__(
 
         # Probably the dependency is not installed
         if self.embedding_model is None:
-            raise ImportError("embedding_model is not a valid embedding model",
-                              "Please install the `semantic` extra to use this feature")
-
+            raise ImportError(
+                "embedding_model is not a valid embedding model",
+                "Please install the `semantic` extra to use this feature",
+            )
 
         # Keeping the tokenizer the same as the sentence model is important
         # for the group semantic meaning to be calculated properly
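Side note on the sentences field kept in this hunk: dataclasses reject plain mutable defaults at class-creation time, which is why field(default_factory=list) is used here and in SentenceChunk below. A standalone illustration, plain Python and independent of chonkie:

from dataclasses import dataclass, field
from typing import List

# A bare mutable default such as `items: List[str] = []` makes @dataclass
# raise ValueError("mutable default ..."); default_factory builds a fresh
# list per instance instead, so instances never share state.


@dataclass
class Good:
    items: List[str] = field(default_factory=list)


a, b = Good(), Good()
a.items.append("x")
assert b.items == []  # each instance got its own list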
93 changes: 48 additions & 45 deletions src/chonkie/chunker/sentence.py
@@ -1,6 +1,5 @@
 from bisect import bisect_left
 from dataclasses import dataclass, field
-
 from itertools import accumulate
 from typing import Any, List, Union
 
@@ -26,6 +25,7 @@ class Sentence:
     end_index: int
     token_count: int
 
+
 @dataclass
 class SentenceChunk(Chunk):
     """Dataclass representing a sentence chunk with metadata.
@@ -40,9 +40,11 @@ class SentenceChunk(Chunk):
         sentences: List of Sentence objects in the chunk
     """
 
+    # Don't redeclare inherited fields
     sentences: List[Sentence] = field(default_factory=list)
 
+
 class SentenceChunker(BaseChunker):
     """SentenceChunker splits the sentences in a text based on token limits and sentence boundaries.
@@ -238,18 +240,20 @@ def _get_token_counts(self, sentences: List[str]) -> List[int]:
 
     def _estimate_token_counts(self, text: str) -> int:
         """Estimate token count using character length."""
-        CHARS_PER_TOKEN = 6.  # Avg. char per token for llama3 is b/w 6-7
+        CHARS_PER_TOKEN = 6.0  # Avg. char per token for llama3 is b/w 6-7
         if type(text) is str:
             return max(1, int(len(text) / CHARS_PER_TOKEN))
         elif type(text) is list and type(text[0]) is str:
             return [max(1, int(len(t) / CHARS_PER_TOKEN)) for t in text]
         else:
-            raise ValueError(f"Unknown type passed to _estimate_token_count: {type(text)}")
-
+            raise ValueError(
+                f"Unknown type passed to _estimate_token_count: {type(text)}"
+            )
 
     def _get_feedback(self, estimate: int, actual: int) -> float:
         """Validate against the actual token counts and correct the estimates."""
-        feedback = 1 - ((estimate-actual)/estimate)
-        return feedback
+        feedback = 1 - ((estimate - actual) / estimate)
+        return feedback
 
     def _prepare_sentences(self, text: str) -> List[Sentence]:
         """Split text into sentences and calculate token counts for each sentence.
@@ -261,67 +265,62 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
             List of Sentence objects
         """
         # Split text into sentences
+        # Split text into sentences
         sentence_texts = self._split_sentences(text)
         if not sentence_texts:
             return []
 
         # Calculate positions once
         positions = []
         current_pos = 0
         for sent in sentence_texts:
             positions.append(current_pos)
             current_pos += len(sent) + 1  # +1 for space/separator
 
         if not self.use_approximate:
             # Get accurate token counts in batch
             token_counts = self._get_token_counts(sentence_texts)
         else:
             # Estimate token counts using character length
             token_counts = self._estimate_token_counts(sentence_texts)
 
         # Create sentence objects
         return [
             Sentence(
-                text=sent,
-                start_index=pos,
-                end_index=pos + len(sent),
-                token_count=count
+                text=sent, start_index=pos, end_index=pos + len(sent), token_count=count
             )
             for sent, pos, count in zip(sentence_texts, positions, token_counts)
         ]
 
     def _prepare_sentences(self, text: str) -> List[Sentence]:
         """Prepare sentences with either estimated or accurate token counts."""
         # Split text into sentences
         sentence_texts = self._split_sentences(text)
         if not sentence_texts:
             return []
 
         # Calculate positions once
         positions = []
         current_pos = 0
         for sent in sentence_texts:
             positions.append(current_pos)
             current_pos += len(sent) + 1  # +1 for space/separator
 
         if not self.use_approximate:
             # Get accurate token counts in batch
             token_counts = self._get_token_counts(sentence_texts)
         else:
             # Estimate token counts using character length
             token_counts = self._estimate_token_counts(sentence_texts)
 
         # Create sentence objects
         return [
             Sentence(
-                text=sent,
-                start_index=pos,
-                end_index=pos + len(sent),
-                token_count=count
+                text=sent, start_index=pos, end_index=pos + len(sent), token_count=count
             )
             for sent, pos, count in zip(sentence_texts, positions, token_counts)
         ]
 
     def _create_chunk(self, sentences: List[Sentence], token_count: int) -> Chunk:
         """Create a chunk from a list of sentences.
@@ -354,41 +353,43 @@ def chunk(self, text: str) -> List[Chunk]:
         """
         if not text.strip():
             return []
 
         # Get prepared sentences with token counts
-        sentences = self._prepare_sentences(text) #28mus
+        sentences = self._prepare_sentences(text)  # 28mus
         if not sentences:
             return []
 
         # Pre-calculate cumulative token counts for bisect
         # Add 1 token for spaces between sentences
-        token_sums = list(accumulate(
-            [s.token_count for s in sentences], lambda a,b: a + b, initial=0
-        ))
+        token_sums = list(
+            accumulate(
+                [s.token_count for s in sentences], lambda a, b: a + b, initial=0
+            )
+        )
 
         chunks = []
-        feedback = 1.
+        feedback = 1.0
         pos = 0
 
         while pos < len(sentences):
             # use updated feedback on the token sums
             token_sums = [int(s * feedback) for s in token_sums]
 
             # Use bisect_left to find initial split point
             target_tokens = token_sums[pos] + self.chunk_size
-            split_idx = bisect_left(token_sums, target_tokens) -1
+            split_idx = bisect_left(token_sums, target_tokens) - 1
             split_idx = min(split_idx, len(sentences))
 
             # Ensure we include at least one sentence beyond pos
             split_idx = max(split_idx, pos + 1)
 
             # Handle minimum sentences requirement
             if split_idx - pos < self.min_sentences_per_chunk:
                 split_idx = pos + self.min_sentences_per_chunk
 
             # Get the estimated token count
-            estimate = token_sums[split_idx] - token_sums[pos]
+            estimate = token_sums[split_idx] - token_sums[pos]
 
             # Get candidate sentences and verify actual token count
             chunk_sentences = sentences[pos:split_idx]
             chunk_text = " ".join(s.text for s in chunk_sentences)
@@ -397,36 +398,38 @@ def chunk(self, text: str) -> List[Chunk]:
             # Given the actual token_count and the estimate, get a feedback value for the next loop
             feedback = self._get_feedback(estimate, actual)
             # print(f"Estimate: {estimate} Actual: {actual} feedback: {feedback}")
 
             # Back off one sentence at a time if we exceeded chunk size
-            while (actual > self.chunk_size and
-                   len(chunk_sentences) > self.min_sentences_per_chunk):
+            while (
+                actual > self.chunk_size
+                and len(chunk_sentences) > self.min_sentences_per_chunk
+            ):
                 split_idx -= 1
                 chunk_sentences = sentences[pos:split_idx]
                 chunk_text = " ".join(s.text for s in chunk_sentences)
                 actual = len(self._encode(chunk_text))
 
             chunks.append(self._create_chunk(chunk_sentences, actual))
 
             # Calculate next position with overlap
             if self.chunk_overlap > 0 and split_idx < len(sentences):
                 # Calculate how many sentences we need for overlap
                 overlap_tokens = 0
                 overlap_idx = split_idx - 1
 
                 while overlap_idx > pos and overlap_tokens < self.chunk_overlap:
                     sent = sentences[overlap_idx]
                     next_tokens = overlap_tokens + sent.token_count + 1  # +1 for space
                     if next_tokens > self.chunk_overlap:
                         break
                     overlap_tokens = next_tokens
                     overlap_idx -= 1
 
                 # Move position to after the overlap
                 pos = overlap_idx + 1
             else:
                 pos = split_idx
 
         return chunks
 
     def __repr__(self) -> str:
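Side note on the chunk() loop above: it first sizes chunks with the cheap character-length estimate, then corrects itself each iteration from the real tokenizer counts via _get_feedback. A self-contained sketch of that correction step using the same formulas as the diff (the numbers are made up for illustration):

CHARS_PER_TOKEN = 6.0  # the diff's heuristic: ~6-7 chars per token for llama3


def estimate_tokens(text: str) -> int:
    return max(1, int(len(text) / CHARS_PER_TOKEN))


def get_feedback(estimate: int, actual: int) -> float:
    # Same formula as _get_feedback: > 1 when we underestimated, < 1 when we overestimated.
    return 1 - ((estimate - actual) / estimate)


text = "word " * 120                       # 600 characters of pretend sentences
estimate = estimate_tokens(text)           # 600 / 6.0 -> 100 tokens
actual = 120                               # suppose the tokenizer really returns 120
feedback = get_feedback(estimate, actual)  # 1 - (100 - 120) / 100 = 1.2
print(int(estimate * feedback))            # 120: the next loop's token sums now track the tokenizer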
