Skip to content

Commit

Permalink
[fix] Correct the start and end indices for TokenChunker in Batch mode (#84) (#109)
Browse files Browse the repository at this point in the history

* [fix] #84: Add proper indices for Batch Token Chunking

* [fix] Indices incorrect due to Overlap

* Add a test case for batch token indices verification
  • Loading branch information
bhavnicksm authored Dec 27, 2024
1 parent 7108923 commit ada6f30
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/chonkie/chunker/token.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Token-based chunking."""

from itertools import accumulate
from typing import Any, Generator, List, Tuple, Union

from chonkie.types import Chunk
Expand Down Expand Up @@ -126,22 +127,33 @@ def _chunk_generator(
if end == len(tokens):
break

def _process_batch(self,
                   chunks: List[Tuple[List[int], int, int]],
                   full_text: str) -> List[Chunk]:
    """Decode a batch of token chunks and locate each one inside ``full_text``.

    Args:
        chunks: ``(tokens, start, end)`` tuples, where ``start`` and ``end``
            are the chunk's token offsets within the tokenized text.
        full_text: The text the chunks were cut from; the character
            indices of every returned chunk refer to this string.

    Returns:
        One ``Chunk`` per input tuple, carrying the decoded text, its
        character ``start_index``/``end_index`` in ``full_text`` and the
        number of tokens in the chunk.

    """
    token_lists = [tokens for tokens, _, _ in chunks]
    texts = self._decode_batch(token_lists)

    # Number of trailing tokens each chunk shares with its successor
    # (zero when the chunks do not overlap).
    overlap_counts = [
        max(end - next_start, 0)
        for (_, _, end), (_, next_start, _) in zip(chunks, chunks[1:])
    ]
    overlap_token_lists = [
        tokens[-count:] if count else []
        for tokens, count in zip(token_lists, overlap_counts)
    ]
    # Only pay for a second decode when there is an overlap to measure.
    if any(overlap_token_lists):
        overlap_lengths = [
            len(overlap_text)
            for overlap_text in self._decode_batch(overlap_token_lists)
        ]
    else:
        overlap_lengths = [0] * len(overlap_token_lists)
    overlap_lengths.append(0)  # the last chunk has no successor

    result = []
    search_from = 0
    previous_start = 0
    for text, tokens, overlap in zip(texts, token_lists, overlap_lengths):
        start_index = full_text.find(text, search_from)
        if start_index == -1:
            # The overlap estimate overshot; chunks are ordered, so the
            # text cannot begin before the previous chunk's start.
            start_index = full_text.find(text, previous_start)
        if start_index == -1:
            # Decoded text does not occur verbatim in full_text; fall back
            # to the expected position instead of emitting -1.
            start_index = search_from
        end_index = start_index + len(text)
        result.append(
            Chunk(
                text=text,
                start_index=start_index,
                end_index=end_index,
                token_count=len(tokens),
            )
        )
        previous_start = start_index
        # The next chunk starts inside this one when they overlap, so back
        # up by the overlap's character length rather than resuming at
        # end_index (which would skip past the true position).
        search_from = max(end_index - overlap, start_index)

    return result

def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
"""Process a batch of texts."""
tokens_list = self._encode_batch(texts)
decoded_texts = self._decode_batch(tokens_list)
result = []

for tokens in tokens_list:
for tokens, text in zip(tokens_list, decoded_texts):
if not tokens:
result.append([])
continue
Expand All @@ -152,13 +164,13 @@ def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
for chunk_data in self._chunk_generator(tokens):
chunk_batch.append(chunk_data)

chunks.extend(self._process_batch(chunk_batch))
chunks.extend(self._process_batch(chunk_batch, text))
result.append(chunks)

return result

def chunk_batch(
self, texts: List[str], batch_size: int = None
self, texts: List[str], batch_size: Union[int, None] = None
) -> List[List[Chunk]]:
"""Split a batch of texts into their respective chunks.
Expand Down
6 changes: 6 additions & 0 deletions tests/chunker/test_token_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,5 +331,11 @@ def test_token_chunker_token_counts(tokenizer, sample_text):
token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks]
assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text"

def test_token_chunker_indices_batch(tokenizer, sample_text):
    """Test that TokenChunker's batch-mode indices correctly map to the original text."""
    batch_size = 10
    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
    batch_chunks = chunker.chunk_batch([sample_text] * batch_size)
    assert len(batch_chunks) == batch_size, "chunk_batch must return one list of chunks per input text"
    # Verify every result in the batch, not only the last one, so that an
    # index error affecting any position in the batch is caught.
    for chunks in batch_chunks:
        verify_chunk_indices(chunks, sample_text)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main()

0 comments on commit ada6f30

Please sign in to comment.