From ada6f306cfdf97e3cefcb9cb49c46b217880e239 Mon Sep 17 00:00:00 2001
From: Bhavnick Minhas <11348086+bhavnicksm@users.noreply.github.com>
Date: Fri, 27 Dec 2024 16:21:06 +0530
Subject: [PATCH] [fix] Correct the start and end indices for TokenChunker in
 Batch mode (#84) (#109)

* [fix] #84: Add proper indices for Batch Token Chunking

* [fix] Indices incorrect due to Overlap

* Add a test case for batch token indices verification

---
 src/chonkie/chunker/token.py        | 24 ++++++++++++++++++------
 tests/chunker/test_token_chunker.py |  6 ++++++
 2 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/src/chonkie/chunker/token.py b/src/chonkie/chunker/token.py
index b6f9ee0..1912163 100644
--- a/src/chonkie/chunker/token.py
+++ b/src/chonkie/chunker/token.py
@@ -1,5 +1,6 @@
 """Token-based chunking."""
 
+from itertools import accumulate
 from typing import Any, Generator, List, Tuple, Union
 
 from chonkie.types import Chunk
@@ -126,22 +127,33 @@ def _chunk_generator(
             if end == len(tokens):
                 break
 
-    def _process_batch(self, chunks: List[Tuple[List[int], int, int]]) -> List[Chunk]:
+    def _process_batch(self,
+                       chunks: List[Tuple[List[int], int, int]],
+                       full_text: str) -> List[Chunk]:
         """Process a batch of chunks."""
         token_lists = [tokens for tokens, _, _ in chunks]
         texts = self._decode_batch(token_lists)
 
+        index_pairs = []
+        current_index = 0
+        for text in texts:
+            start_index = full_text.find(text, current_index)
+            end_index = start_index + len(text)
+            index_pairs.append((start_index, end_index))
+            current_index = end_index
+
         return [
-            Chunk(text=text, start_index=start, end_index=end, token_count=end - start)
-            for text, (_, start, end) in zip(texts, chunks)
+            Chunk(text=text, start_index=start, end_index=end, token_count=len(tokens))
+            for text, (start, end), tokens in zip(texts, index_pairs, token_lists)
         ]
 
     def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
         """Process a batch of texts."""
         tokens_list = self._encode_batch(texts)
+        decoded_texts = self._decode_batch(tokens_list)
         result = []
 
-        for tokens in tokens_list:
+        for tokens, text in zip(tokens_list, decoded_texts):
             if not tokens:
                 result.append([])
                 continue
@@ -152,13 +164,13 @@ def _process_text_batch(self, texts: List[str]) -> List[List[Chunk]]:
             for chunk_data in self._chunk_generator(tokens):
                 chunk_batch.append(chunk_data)
 
-            chunks.extend(self._process_batch(chunk_batch))
+            chunks.extend(self._process_batch(chunk_batch, text))
             result.append(chunks)
 
         return result
 
     def chunk_batch(
-        self, texts: List[str], batch_size: int = None
+        self, texts: List[str], batch_size: Union[int, None] = None
     ) -> List[List[Chunk]]:
         """Split a batch of texts into their respective chunks.
 
diff --git a/tests/chunker/test_token_chunker.py b/tests/chunker/test_token_chunker.py
index 3394b1d..b990bcc 100644
--- a/tests/chunker/test_token_chunker.py
+++ b/tests/chunker/test_token_chunker.py
@@ -331,5 +331,11 @@ def test_token_chunker_token_counts(tokenizer, sample_text):
     token_counts = [len(tokenizer.encode(chunk.text)) for chunk in chunks]
     assert all([chunk.token_count == token_count for chunk, token_count in zip(chunks, token_counts)]), "All chunks must have a token count equal to the length of the encoded text"
 
+def test_token_chunker_indices_batch(tokenizer, sample_text):
+    """Test that TokenChunker's indices correctly map to original text."""
+    chunker = TokenChunker(tokenizer=tokenizer, chunk_size=512, chunk_overlap=128)
+    chunks = chunker.chunk_batch([sample_text]*10)[-1]
+    verify_chunk_indices(chunks, sample_text)
+
 if __name__ == "__main__":
     pytest.main()