diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py
index f6feb01..bee35ae 100644
--- a/nano_graphrag/_op.py
+++ b/nano_graphrag/_op.py
@@ -4,6 +4,8 @@
 from typing import Union
 from collections import Counter, defaultdict
 
+import tiktoken
+
 from ._utils import (
     logger,
     clean_str,
@@ -28,27 +30,102 @@ from .prompt import GRAPH_FIELD_SEP, PROMPTS
 
 
 def chunking_by_token_size(
-    content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
-):
-    tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
-    results = []
-    for index, start in enumerate(
-        range(0, len(tokens), max_token_size - overlap_token_size)
-    ):
-        chunk_content = decode_tokens_by_tiktoken(
-            tokens[start : start + max_token_size], model_name=tiktoken_model
-        )
-        results.append(
-            {
-                "tokens": min(max_token_size, len(tokens) - start),
-                "content": chunk_content.strip(),
-                "chunk_order_index": index,
-            }
-        )
-    return results
+    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024
+):
+    results = []
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = []
+        lengths = []
+        for start in range(0, len(tokens), max_token_size - overlap_token_size):
+            chunk_token.append(tokens[start : start + max_token_size])
+            lengths.append(min(max_token_size, len(tokens) - start))
+
+        # Somewhat tricky: the corpus-level token structure is list[list[list[int]]]
+        # (corpus -> doc -> chunk), so it cannot be decoded in a single call;
+        # decode each document's chunks as one batch instead.
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
+
+    return results
+
+
+def chunking_by_seperators(
+    tokens_list: list[list[int]], doc_keys, tiktoken_model, overlap_token_size=128, max_token_size=1024
+):
+    from nano_graphrag._spliter import SeparatorSplitter
+
+    DEFAULT_SEPERATORS = [
+        # Paragraph separators
+        "\n\n",
+        "\r\n\r\n",
+        # Line breaks
+        "\n",
+        "\r\n",
+        # Sentence-ending punctuation
+        "。",  # Chinese period
+        "．",  # Full-width dot
+        ".",  # English period
+        "！",  # Chinese exclamation mark
+        "!",  # English exclamation mark
+        "？",  # Chinese question mark
+        "?",  # English question mark
+        # Whitespace characters
+        " ",  # Space
+        "\t",  # Tab
+        "\u3000",  # Full-width space
+        # Special characters
+        "\u200b",  # Zero-width space (used in some Asian languages)
+    ]
+
+    splitter = SeparatorSplitter(
+        separators=[tiktoken_model.encode(s) for s in DEFAULT_SEPERATORS],
+        chunk_size=max_token_size,
+        chunk_overlap=overlap_token_size,
+    )
+    results = []
+    for index, tokens in enumerate(tokens_list):
+        chunk_token = splitter.split_tokens(tokens)
+        lengths = [len(c) for c in chunk_token]
+
+        # Same caveat as above: decode each document's chunks as one batch.
+        chunk_token = tiktoken_model.decode_batch(chunk_token)
+        for i, chunk in enumerate(chunk_token):
+            results.append(
+                {
+                    "tokens": lengths[i],
+                    "content": chunk.strip(),
+                    "chunk_order_index": i,
+                    "full_doc_id": doc_keys[index],
+                }
+            )
+
+    return results
+
+
+def get_chunks(new_docs, chunk_func=chunking_by_token_size, **chunk_func_params):
+    inserting_chunks = {}
+
+    new_docs_list = list(new_docs.items())
+    docs = [new_doc[1]["content"] for new_doc in new_docs_list]
+    doc_keys = [new_doc[0] for new_doc in new_docs_list]
+
+    ENCODER = tiktoken.encoding_for_model("gpt-4o")
+    tokens = ENCODER.encode_batch(docs, num_threads=16)
+    chunks = chunk_func(tokens, doc_keys=doc_keys, tiktoken_model=ENCODER, **chunk_func_params)
+
+    for chunk in chunks:
+        inserting_chunks.update({compute_mdhash_id(chunk["content"], prefix="chunk-"): chunk})
+
+    return inserting_chunks
 
 
 async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
diff --git a/nano_graphrag/_spliter.py b/nano_graphrag/_spliter.py
new file mode 100644
index 0000000..1054d17
--- /dev/null
+++ b/nano_graphrag/_spliter.py
@@ -0,0 +1,94 @@
+from typing import List, Optional, Union, Literal
+
+class SeparatorSplitter:
+    def __init__(
+        self,
+        separators: Optional[List[List[int]]] = None,
+        keep_separator: Union[bool, Literal["start", "end"]] = "end",
+        chunk_size: int = 4000,
+        chunk_overlap: int = 200,
+        length_function: callable = len,
+    ):
+        self._separators = separators or []
+        self._keep_separator = keep_separator
+        self._chunk_size = chunk_size
+        self._chunk_overlap = chunk_overlap
+        self._length_function = length_function
+
+    def split_tokens(self, tokens: List[int]) -> List[List[int]]:
+        splits = self._split_tokens_with_separators(tokens)
+        return self._merge_splits(splits)
+
+    def _split_tokens_with_separators(self, tokens: List[int]) -> List[List[int]]:
+        splits = []
+        current_split = []
+        i = 0
+        while i < len(tokens):
+            separator_found = False
+            for separator in self._separators:
+                if tokens[i : i + len(separator)] == separator:
+                    if self._keep_separator in [True, "end"]:
+                        current_split.extend(separator)
+                    if current_split:
+                        splits.append(current_split)
+                        current_split = []
+                    if self._keep_separator == "start":
+                        current_split.extend(separator)
+                    i += len(separator)
+                    separator_found = True
+                    break
+            if not separator_found:
+                current_split.append(tokens[i])
+                i += 1
+        if current_split:
+            splits.append(current_split)
+        return [s for s in splits if s]
+
+    def _merge_splits(self, splits: List[List[int]]) -> List[List[int]]:
+        if not splits:
+            return []
+
+        merged_splits = []
+        current_chunk = []
+
+        for split in splits:
+            if not current_chunk:
+                current_chunk = split
+            elif self._length_function(current_chunk) + self._length_function(split) <= self._chunk_size:
+                current_chunk.extend(split)
+            else:
+                merged_splits.append(current_chunk)
+                current_chunk = split
+
+        if current_chunk:
+            merged_splits.append(current_chunk)
+
+        if len(merged_splits) == 1 and self._length_function(merged_splits[0]) > self._chunk_size:
+            return self._split_chunk(merged_splits[0])
+
+        if self._chunk_overlap > 0:
+            return self._enforce_overlap(merged_splits)
+
+        return merged_splits
+
+    def _split_chunk(self, chunk: List[int]) -> List[List[int]]:
+        result = []
+        for i in range(0, len(chunk), self._chunk_size - self._chunk_overlap):
+            new_chunk = chunk[i : i + self._chunk_size]
+            if len(new_chunk) > self._chunk_overlap:  # only keep a chunk if it is longer than the overlap
+                result.append(new_chunk)
+        return result
+
+    def _enforce_overlap(self, chunks: List[List[int]]) -> List[List[int]]:
+        result = []
+        for i, chunk in enumerate(chunks):
+            if i == 0:
+                result.append(chunk)
+            else:
+                overlap = chunks[i - 1][-self._chunk_overlap:]
+                new_chunk = overlap + chunk
+                if self._length_function(new_chunk) > self._chunk_size:
+                    new_chunk = new_chunk[: self._chunk_size]
+                result.append(new_chunk)
+        return result
+
diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py
index 3cf938a..530ad83 100644
--- a/nano_graphrag/graphrag.py
+++ b/nano_graphrag/graphrag.py
@@ -5,6 +5,8 @@
 from functools import partial
 from typing import Callable, Dict, List, Optional, Type, Union, cast
 
+import tiktoken
+
 from ._llm import (
     gpt_4o_complete,
@@ -18,6 +20,7 @@
     chunking_by_token_size,
     extract_entities,
     generate_community_report,
+    get_chunks,
     local_query,
     global_query,
     naive_query,
@@ -65,7 +68,7 @@ class GraphRAG:
     enable_naive_rag: bool = False
 
     # text chunking
-    chunk_func: Callable[[str, Optional[int], Optional[int], Optional[str]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
+    chunk_func: Callable[[list[list[int]], List[str], tiktoken.Encoding, Optional[int], Optional[int]], List[Dict[str, Union[str, int]]]] = chunking_by_token_size
     chunk_token_size: int = 1200
     chunk_overlap_token_size: int = 100
     tiktoken_model_name: str = "gpt-4o"
@@ -263,21 +266,11 @@ async def ainsert(self, string_or_strings):
             logger.info(f"[New Docs] inserting {len(new_docs)} docs")
 
             # ---------- chunking
-            inserting_chunks = {}
-            for doc_key, doc in new_docs.items():
-                chunks = {
-                    compute_mdhash_id(dp["content"], prefix="chunk-"): {
-                        **dp,
-                        "full_doc_id": doc_key,
-                    }
-                    for dp in self.chunk_func(
-                        doc["content"],
-                        overlap_token_size=self.chunk_overlap_token_size,
-                        max_token_size=self.chunk_token_size,
-                        tiktoken_model=self.tiktoken_model_name,
-                    )
-                }
-                inserting_chunks.update(chunks)
+            inserting_chunks = get_chunks(
+                new_docs=new_docs,
+                chunk_func=self.chunk_func,
+                overlap_token_size=self.chunk_overlap_token_size,
+                max_token_size=self.chunk_token_size,
+            )
+
             _add_chunk_keys = await self.text_chunks.filter_keys(
                 list(inserting_chunks.keys())
             )
diff --git a/tests/test_splitter.py b/tests/test_splitter.py
new file mode 100644
index 0000000..551de01
--- /dev/null
+++ b/tests/test_splitter.py
@@ -0,0 +1,63 @@
+import unittest
+from typing import List
+
+from nano_graphrag._spliter import SeparatorSplitter
+
+# Assuming the SeparatorSplitter class is already imported
+
+class TestSeparatorSplitter(unittest.TestCase):
+
+    def setUp(self):
+        self.tokenize = lambda text: [ord(c) for c in text]  # Simple tokenizer for testing
+        self.detokenize = lambda tokens: ''.join(chr(t) for t in tokens)
+
+    def test_split_with_custom_separator(self):
+        splitter = SeparatorSplitter(
+            separators=[self.tokenize('\n'), self.tokenize('.')],
+            chunk_size=19,
+            chunk_overlap=0,
+            keep_separator="end"
+        )
+        text = "This is a test.\nAnother test."
+        tokens = self.tokenize(text)
+        expected = [
+            self.tokenize("This is a test.\n"),
+            self.tokenize("Another test."),
+        ]
+        result = splitter.split_tokens(tokens)
+
+        self.assertEqual(result, expected)
+
+    def test_chunk_size_limit(self):
+        splitter = SeparatorSplitter(
+            chunk_size=5,
+            chunk_overlap=0,
+            separators=[self.tokenize("\n")]
+        )
+        text = "1234567890"
+        tokens = self.tokenize(text)
+        expected = [
+            self.tokenize("12345"),
+            self.tokenize("67890")
+        ]
+        result = splitter.split_tokens(tokens)
+        self.assertEqual(result, expected)
+
+    def test_chunk_overlap(self):
+        splitter = SeparatorSplitter(
+            chunk_size=5,
+            chunk_overlap=2,
+            separators=[self.tokenize("\n")]
+        )
+        text = "1234567890"
+        tokens = self.tokenize(text)
+        expected = [
+            self.tokenize("12345"),
+            self.tokenize("45678"),
+            self.tokenize("7890"),
+        ]
+        result = splitter.split_tokens(tokens)
+        self.assertEqual(result, expected)
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
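
Usage note (not part of the patch): a minimal sketch of how the new pieces might be exercised, assuming the package layout shown in the diff above. The SeparatorSplitter constructor, split_tokens, chunking_by_seperators, and the chunk_func / chunk_token_size / chunk_overlap_token_size fields come from the patch itself; the working_dir value, the sample strings, the small chunk sizes, and the synchronous insert() wrapper around ainsert() are illustrative assumptions, not part of this change.

    import tiktoken

    from nano_graphrag import GraphRAG
    from nano_graphrag._op import chunking_by_seperators
    from nano_graphrag._spliter import SeparatorSplitter

    # 1) Drive the splitter directly: separators are token sequences, so they are
    #    encoded with the same tiktoken encoding used for the documents.
    enc = tiktoken.encoding_for_model("gpt-4o")
    splitter = SeparatorSplitter(
        separators=[enc.encode("\n\n"), enc.encode(". ")],
        chunk_size=16,   # small values just for illustration
        chunk_overlap=4,
    )
    token_chunks = splitter.split_tokens(enc.encode("First paragraph.\n\nSecond paragraph."))
    print([enc.decode(chunk) for chunk in token_chunks])

    # 2) Plug the separator-aware chunker into GraphRAG via the new chunk_func field.
    rag = GraphRAG(
        working_dir="./example_workdir",   # placeholder path
        chunk_func=chunking_by_seperators,
        chunk_token_size=1200,
        chunk_overlap_token_size=100,
    )
    rag.insert("Some long document text ...")  # placeholder text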