Merge pull request #126 from Udayk02/development
[FIX] #116: Incorrect `start_index` when `chunk_overlap` is not 0
bhavnicksm authored Jan 4, 2025
2 parents 552a068 + 5d401a0 commit 5b75303
Showing 1 changed file with 18 additions and 17 deletions.
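Context for the change: with a non-zero `chunk_overlap`, the old code located each chunk via `decoded_text.find(chunk_text, current_index)`, searching from the previous chunk's end. An overlapping chunk actually starts before that point, so `find` either matched a later spurious occurrence or returned -1, and `start_index` was wrong. The fix drops the `find` call: it decodes the overlapping tail of each token group and subtracts its character length from the running index, so each chunk's `start_index` is derived directly from where the overlap begins.

A minimal check of the intended behaviour (a sketch only: the `TokenChunker` constructor arguments are assumed from chonkie's public API and may differ between versions, and the offset check additionally assumes the tokenizer decodes losslessly back to the input text):

from chonkie import TokenChunker

chunker = TokenChunker(chunk_size=512, chunk_overlap=128)  # tokenizer left at its default
text = "some long document ... " * 500
chunks = chunker.chunk(text)

for chunk in chunks:
    # After the fix, start_index/end_index should line up with the source
    # text even though consecutive chunks overlap.
    assert text[chunk.start_index:chunk.end_index] == chunk.text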
35 changes: 18 additions & 17 deletions src/chonkie/chunker/token.py
@@ -51,28 +51,36 @@ def __init__(

     def _create_chunks(
         self,
-        chunk_texts: List[str],
         token_counts: List[int],
-        decoded_text: str,
+        token_groups: List[List[int]]
     ) -> List[Chunk]:
         """Create chunks from a list of texts."""
         # package everything as Chunk objects and send out the result
+        chunk_texts = self._decode_batch(token_groups)
         chunks = []
         current_index = 0
-        for chunk_text, token_count in zip(chunk_texts, token_counts):
-            start_index = decoded_text.find(
-                chunk_text, current_index
-            )  # Find needs to be run every single time because of unknown overlap length
-            end_index = start_index + len(chunk_text)

+        if (self.chunk_overlap > 0):
+            overlap_tokens_space = [
+                # we get the space taken by the overlapping text, that gives you the start_index for the next chunk
+                len(overlap_text)
+                for overlap_text in self._decode_batch([token_group[-(self.chunk_overlap - (self.chunk_size - len(token_group))):]
+                                                        for token_group in token_groups])
+            ]
+
+        for i, (chunk_text, token_count) in enumerate(zip(chunk_texts, token_counts)):
+            end_index = current_index + len(chunk_text)
             chunks.append(
                 Chunk(
                     text=chunk_text,
-                    start_index=start_index,
+                    start_index=current_index,
                     end_index=end_index,
                     token_count=token_count,
                 )
             )
-            current_index = end_index
+
+            current_index = end_index - (overlap_tokens_space[i] if self.chunk_overlap > 0 else 0)

         return chunks

     def chunk(self, text: str) -> List[Chunk]:
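The slice `token_group[-(self.chunk_overlap - (self.chunk_size - len(token_group))):]` picks out exactly the tokens the next group re-uses: `chunk_overlap` of them for a full-size group, fewer for a trailing group that runs out of tokens. Subtracting the decoded length of that tail from a chunk's `end_index` gives the next chunk's `start_index`. A toy walk-through of the bookkeeping, using whitespace tokens in place of real token IDs (everything below is local to this sketch, not chonkie code):

chunk_size, chunk_overlap = 5, 2
text = "a bb ccc dd e ff ggg hh i jj"
tokens = text.split(" ")  # stand-in for tokenizer output

step = chunk_size - chunk_overlap
token_groups = [tokens[i:i + chunk_size] for i in range(0, len(tokens), step)]

current_index = 0
for group in token_groups:
    chunk_text = " ".join(group)  # stand-in for _decode_batch
    end_index = current_index + len(chunk_text)
    assert text[current_index:end_index] == chunk_text
    # characters occupied by the overlapping tail, i.e. how far before this
    # chunk's end the next chunk starts
    tail = group[-(chunk_overlap - (chunk_size - len(group))):]
    current_index = end_index - len(" ".join(tail))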
@@ -91,9 +99,6 @@ def chunk(self, text: str) -> List[Chunk]:
         # Encode full text
         text_tokens = self._encode(text)

-        # We decode the text because the tokenizer might result in a different output than text
-        decoded_text = self._decode(text_tokens)
-
         # Calculate chunk positions
         token_groups = [
             text_tokens[
@@ -107,11 +112,7 @@
             len(toks) for toks in token_groups
         ]  # get the token counts; it's prolly chunk_size, but len doesn't take too long

-        chunk_texts = self._decode_batch(
-            token_groups
-        )  # decrease the time by decoding in one go (?)
-
-        chunks = self._create_chunks(chunk_texts, token_counts, decoded_text)
+        chunks = self._create_chunks(token_counts, token_groups)

         return chunks
