feat: speed up chunking & add separator chunking (#48)
* speed up chunking & add separator chunking

* add test code for splitter & reformat chunking methods

* typo

* fix overlap behaviour

* typo

* typo for type check
rangehow authored Sep 19, 2024
1 parent 70bbb67 commit 13ce7d1
Showing 4 changed files with 259 additions and 32 deletions.
109 changes: 93 additions & 16 deletions nano_graphrag/_op.py
@@ -4,6 +4,8 @@
from typing import Union
from collections import Counter, defaultdict

import tiktoken

from ._utils import (
logger,
clean_str,
@@ -28,27 +30,102 @@
from .prompt import GRAPH_FIELD_SEP, PROMPTS




def chunking_by_token_size(
    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
):
    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
    results = []
    for index, start in enumerate(
        range(0, len(tokens), max_token_size - overlap_token_size)
    ):
        chunk_content = decode_tokens_by_tiktoken(
            tokens[start : start + max_token_size], model_name=tiktoken_model
        )
        results.append(
            {
                "tokens": min(max_token_size, len(tokens) - start),
                "content": chunk_content.strip(),
                "chunk_order_index": index,
            }
        )
    return results

def chunking_by_token_size(
    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024
):
    results = []
    for index, tokens in enumerate(tokens_list):
        chunk_token = []
        lengths = []
        for start in range(0, len(tokens), max_token_size - overlap_token_size):
            chunk_token.append(tokens[start : start + max_token_size])
            lengths.append(min(max_token_size, len(tokens) - start))

        # Somewhat tricky: across the corpus the token structure is
        # list[list[list[int]]] (corpus -> doc -> chunk), so it cannot be
        # decoded in one call; decode each document's chunks as a batch.
        chunk_token = tiktoken_model.decode_batch(chunk_token)
        for i, chunk in enumerate(chunk_token):
            results.append(
                {
                    "tokens": lengths[i],
                    "content": chunk.strip(),
                    "chunk_order_index": i,
                    "full_doc_id": doc_keys[index],
                }
            )

    return results
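For reference, a minimal usage sketch of the reworked function (not part of the diff; ENCODER, docs, and doc_keys are illustrative names):

import tiktoken

from nano_graphrag._op import chunking_by_token_size

# Batch-encode a small corpus, then chunk it with the new list-based signature.
ENCODER = tiktoken.encoding_for_model("gpt-4o")
docs = ["first document text ...", "second document text ..."]
doc_keys = ["doc-0", "doc-1"]
tokens_list = ENCODER.encode_batch(docs, num_threads=16)
chunks = chunking_by_token_size(
    tokens_list, doc_keys=doc_keys, tiktoken_model=ENCODER,
    overlap_token_size=128, max_token_size=1024,
)
# Each item: {"tokens": ..., "content": ..., "chunk_order_index": ..., "full_doc_id": ...}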

def chunking_by_seperators(
    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024
):
    from nano_graphrag._spliter import SeparatorSplitter

    DEFAULT_SEPERATORS = [
        # Paragraph separators
        "\n\n",
        "\r\n\r\n",
        # Line breaks
        "\n",
        "\r\n",
        # Sentence-ending punctuation
        "。",  # Chinese period
        "．",  # Full-width dot
        ".",  # English period
        "！",  # Chinese exclamation mark
        "!",  # English exclamation mark
        "？",  # Chinese question mark
        "?",  # English question mark
        # Whitespace characters
        " ",  # Space
        "\t",  # Tab
        "\u3000",  # Full-width space
        # Special characters
        "\u200b",  # Zero-width space (used in some Asian languages)
    ]

    splitter = SeparatorSplitter(
        separators=[tiktoken_model.encode(s) for s in DEFAULT_SEPERATORS],
        chunk_size=max_token_size,
        chunk_overlap=overlap_token_size,
    )
    results = []
    for index, tokens in enumerate(tokens_list):
        chunk_token = splitter.split_tokens(tokens)
        lengths = [len(c) for c in chunk_token]

        # Somewhat tricky: across the corpus the token structure is
        # list[list[list[int]]] (corpus -> doc -> chunk), so it cannot be
        # decoded in one call; decode each document's chunks as a batch.
        chunk_token = tiktoken_model.decode_batch(chunk_token)
        for i, chunk in enumerate(chunk_token):
            results.append(
                {
                    "tokens": lengths[i],
                    "content": chunk.strip(),
                    "chunk_order_index": i,
                    "full_doc_id": doc_keys[index],
                }
            )

return results
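The separator-aware variant is called the same way — a hedged sketch (ENCODER as above, text illustrative), where paragraph breaks become preferred split points before the token budget forces a cut:

text = "First paragraph.\n\nSecond paragraph, which runs a little longer."
chunks = chunking_by_seperators(
    [ENCODER.encode(text)], doc_keys=["doc-0"], tiktoken_model=ENCODER,
    overlap_token_size=0, max_token_size=16,
)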


def get_chunks(new_docs, chunk_func=chunking_by_token_size, **chunk_func_params):
    inserting_chunks = {}

    new_docs_list = list(new_docs.items())
    docs = [new_doc[1]["content"] for new_doc in new_docs_list]
    doc_keys = [new_doc[0] for new_doc in new_docs_list]

    ENCODER = tiktoken.encoding_for_model("gpt-4o")
    tokens = ENCODER.encode_batch(docs, num_threads=16)
    chunks = chunk_func(tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params)

    for chunk in chunks:
        inserting_chunks.update({compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk})

    return inserting_chunks
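An end-to-end sketch of get_chunks (the new_docs shape mirrors what ainsert builds below; the doc id is illustrative):

new_docs = {"doc-abc": {"content": "some long document text ..."}}
inserting_chunks = get_chunks(
    new_docs,
    chunk_func=chunking_by_seperators,  # or the default chunking_by_token_size
    overlap_token_size=100,
    max_token_size=1200,
)
# Keys look like "chunk-<md5-of-content>", so identical chunks dedupe across docs.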


async def _handle_entity_relation_summary(
entity_or_relation_name: str,
description: str,
94 changes: 94 additions & 0 deletions nano_graphrag/_spliter.py
@@ -0,0 +1,94 @@
from typing import List, Optional, Union, Literal

class SeparatorSplitter:
def __init__(
self,
separators: Optional[List[List[int]]] = None,
keep_separator: Union[bool, Literal["start", "end"]] = "end",
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: callable = len,
):
self._separators = separators or []
self._keep_separator = keep_separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function

def split_tokens(self, tokens: List[int]) -> List[List[int]]:
splits = self._split_tokens_with_separators(tokens)
return self._merge_splits(splits)

def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
splits = []
current_split = []
i = 0
while i < len(tokens):
separator_found = False
for separator in self._separators:
if tokens[i:i+len(separator)] == separator:
if self._keep_separator in [True, "end"]:
current_split.extend(separator)
if current_split:
splits.append(current_split)
current_split = []
if self._keep_separator == "start":
current_split.extend(separator)
i += len(separator)
separator_found = True
break
if not separator_found:
current_split.append(tokens[i])
i += 1
if current_split:
splits.append(current_split)
return [s for s in splits if s]

def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
if not splits:
return []

merged_splits = []
current_chunk = []

for split in splits:
if not current_chunk:
current_chunk = split
elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
current_chunk.extend(split)
else:
merged_splits.append(current_chunk)
current_chunk = split

if current_chunk:
merged_splits.append(current_chunk)

if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
return self._split_chunk(merged_splits[0])

if self._chunk_overlap > 0:
return self._enforce_overlap(merged_splits)

return merged_splits

def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
result = []
for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
new_chunk = chunk[i:i + self._chunk_size]
            if len(new_chunk) > self._chunk_overlap:  # only keep a chunk when it is longer than the overlap
result.append(new_chunk)
return result

def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
result = []
for i, chunk in enumerate(chunks):
if i == 0:
result.append(chunk)
else:
overlap = chunks[i-1][-self._chunk_overlap:]
new_chunk = overlap + chunk
if self._length_function(new_chunk) > self._chunk_size:
new_chunk = new_chunk[:self._chunk_size]
result.append(new_chunk)
return result
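A minimal sketch of driving SeparatorSplitter directly with a real tiktoken encoding (encoding choice and sample text are illustrative):

import tiktoken

from nano_graphrag._spliter import SeparatorSplitter

enc = tiktoken.encoding_for_model("gpt-4o")
splitter = SeparatorSplitter(
    separators=[enc.encode("\n\n"), enc.encode("\n")],  # paragraph break preferred over line break
    chunk_size=1200,
    chunk_overlap=100,
)
token_chunks = splitter.split_tokens(enc.encode("para one\n\npara two\n\npara three"))
texts = enc.decode_batch(token_chunks)  # back to strings, one per chunk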

25 changes: 9 additions & 16 deletions nano_graphrag/graphrag.py
@@ -5,6 +5,8 @@
from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union, cast

import tiktoken


from ._llm import (
gpt_4o_complete,
@@ -18,6 +20,7 @@
chunking_by_token_size,
extract_entities,
generate_community_report,
get_chunks,
local_query,
global_query,
naive_query,
@@ -65,7 +68,7 @@ class GraphRAG:
enable_naive_rag: bool = False

# text chunking
chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
chunk_func: Callable[[list[list[int]], List[str], tiktoken.Encoding, Optional[int], Optional[int]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
chunk_token_size: int = 1200
chunk_overlap_token_size: int = 100
tiktoken_model_name: str = "gpt-4o"
@@ -263,21 +266,11 @@ async def ainsert(self, string_or_strings):
logger.info(f"[New Docs] inserting {len(new_docs)} docs")

# ---------- chunking
inserting_chunks = {}
for doc_key, doc in new_docs.items():
chunks = {
compute_mdhash_id(dp["content"], prefix="chunk-"): {
**dp,
"full_doc_id": doc_key,
}
for dp in self.chunk_func(
doc["content"],
overlap_token_size=self.chunk_overlap_token_size,
max_token_size=self.chunk_token_size,
tiktoken_model=self.tiktoken_model_name,
)
}
inserting_chunks.update(chunks)

        inserting_chunks = get_chunks(
            new_docs=new_docs,
            chunk_func=self.chunk_func,
            overlap_token_size=self.chunk_overlap_token_size,
            max_token_size=self.chunk_token_size,
        )


_add_chunk_keys = await self.text_chunks.filter_keys(
list(inserting_chunks.keys())
)
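Since chunk_func is now a GraphRAG field, callers can opt into separator chunking at construction time — a hedged sketch (working_dir and the inserted text are illustrative):

from nano_graphrag import GraphRAG
from nano_graphrag._op import chunking_by_seperators

rag = GraphRAG(working_dir="./my_rag", chunk_func=chunking_by_seperators)
rag.insert("some long document text ...")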
63 changes: 63 additions & 0 deletions tests/test_splitter.py
@@ -0,0 +1,63 @@
import unittest
from typing import List

from nano_graphrag._spliter import SeparatorSplitter

# Assuming the SeparatorSplitter class is already imported

class TestSeparatorSplitter(unittest.TestCase):

def setUp(self):
self.tokenize = lambda text: [ord(c) for c in text] # Simple tokenizer for testing
self.detokenize = lambda tokens: ''.join(chr(t) for t in tokens)

def test_split_with_custom_separator(self):
splitter = SeparatorSplitter(
separators=[self.tokenize('\n'), self.tokenize('.')],
chunk_size=19,
chunk_overlap=0,
keep_separator="end"
)
text = "This is a test.\nAnother test."
tokens = self.tokenize(text)
expected = [
self.tokenize("This is a test.\n"),
self.tokenize("Another test."),
]
result = splitter.split_tokens(tokens)

self.assertEqual(result, expected)

def test_chunk_size_limit(self):
splitter = SeparatorSplitter(
chunk_size=5,
chunk_overlap=0,
separators=[self.tokenize("\n")]
)
text = "1234567890"
tokens = self.tokenize(text)
expected = [
self.tokenize("12345"),
self.tokenize("67890")
]
result = splitter.split_tokens(tokens)
self.assertEqual(result, expected)

def test_chunk_overlap(self):
splitter = SeparatorSplitter(
chunk_size=5,
chunk_overlap=2,
separators=[self.tokenize("\n")]
)
text = "1234567890"
tokens = self.tokenize(text)
expected = [
self.tokenize("12345"),
self.tokenize("45678"),
self.tokenize("7890"),
]
result = splitter.split_tokens(tokens)
self.assertEqual(result, expected)

if __name__ == '__main__':
unittest.main()
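These tests should run with the standard unittest runner from the repository root, e.g. python -m unittest discover tests (or python -m unittest tests.test_splitter if the tests package is importable).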
