diff --git a/examples/benchmarks/dspy_entity.py b/examples/benchmarks/dspy_entity.py
index bac40c9..812a048 100644
--- a/examples/benchmarks/dspy_entity.py
+++ b/examples/benchmarks/dspy_entity.py
@@ -1,57 +1,19 @@
 import dspy
 import os
 from dotenv import load_dotenv
-from openai import AsyncOpenAI, OpenAI
+from openai import AsyncOpenAI
 import logging
 import asyncio
-from nano_graphrag._op import extract_entities, extract_entities_dspy
+from nano_graphrag.entity_extraction.extract import extract_entities_dspy
 from nano_graphrag._storage import NetworkXStorage, BaseKVStorage
 from nano_graphrag._utils import compute_mdhash_id, compute_args_hash
-from nano_graphrag.prompt import PROMPTS

 WORKING_DIR = "./nano_graphrag_cache_dspy_entity"

 load_dotenv()

-logging.basicConfig(level=logging.WARNING)
-logging.getLogger("nano-graphrag").setLevel(logging.DEBUG)
-
-
-class DeepSeek(dspy.Module):
-    def __init__(self, model, api_key, **kwargs):
-        self.model = model
-        self.api_key = api_key
-        self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
-        self.provider = "default",
-        self.history = []
-        self.kwargs = {
-            "temperature": 0.2,
-            "max_tokens": 2048,
-            **kwargs
-        }
-
-    def basic_request(self, prompt, **kwargs):
-        response = self.client.chat.completions.create(
-            model=self.model,
-            messages=[
-                {"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": prompt}
-            ],
-            stream=False,
-            **self.kwargs
-        )
-        self.history.append({"prompt": prompt, "response": response})
-        return response
-
-    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
-        response = self.basic_request(prompt, **kwargs)
-        completions = [choice.message.content for choice in response.choices]
-        return completions
-
-    def inspect_history(self, n: int = 1):
-        if len(self.history) < n:
-            return self.history
-        return self.history[-n:]
+logger = logging.getLogger("nano-graphrag")
+logger.setLevel(logging.DEBUG)


 async def deepseepk_model_if_cache(
@@ -88,64 +50,12 @@ async def deepseepk_model_if_cache(
     return response.choices[0].message.content


-class EntityTypeExtractionSignature(dspy.Signature):
-    input_text = dspy.InputField(desc="The text to extract entity types from")
-    entity_types = dspy.OutputField(desc="List of entity types present in the text")
-
-
-class EntityExtractionSignature(dspy.Signature):
-    input_text = dspy.InputField(desc="The text to extract entities and relationships from")
-    entities = dspy.OutputField(desc="List of extracted entities with their types and descriptions")
-    relationships = dspy.OutputField(desc="List of relationships between entities, including descriptions and importance scores")
-    reasoning = dspy.OutputField(desc="Step-by-step reasoning for entity and relationship extraction")
-
-
-class EntityExtractor(dspy.Module):
-    def __init__(self):
-        super().__init__()
-        self.type_extractor = dspy.ChainOfThought(EntityTypeExtractionSignature)
-        self.cot = dspy.ChainOfThought(EntityExtractionSignature)
-
-    def forward(self, input_text):
-        type_result = self.type_extractor(input_text=input_text)
-        entity_types = type_result.entity_types
-        prompt_template = PROMPTS["entity_extraction"]
-        formatted_prompt = prompt_template.format(
-            input_text=input_text,
-            entity_types=entity_types,
-            tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
-            record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
-            completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"]
-        )
-        return self.cot(input_text=formatted_prompt)
-
-
-# def extract_entities_dspy(text):
-#     dspy_extractor = EntityExtractor()
-#     dspy_result = dspy_extractor(input_text=text)
-
-#     print("DSPY Result:")
-#     print("\nReasoning:")
-#     print(dspy_result.reasoning)
-#     print("\nEntities:")
-#     entities = dspy_result.entities.split(PROMPTS["DEFAULT_RECORD_DELIMITER"])
-#     for entity in entities:
-#         if entity.strip():
-#             print(entity.strip())
-
-#     print("\nRelationships:")
-#     relationships = dspy_result.relationships.split(PROMPTS["DEFAULT_RECORD_DELIMITER"])
-#     for relationship in relationships:
-#         if relationship.strip():
-#             print(relationship.strip())
-
-
-async def nano_entity_extraction(text):
+async def nano_entity_extraction(text: str, system_prompt: str = None):
     graph_storage = NetworkXStorage(namespace="test", global_config={
         "working_dir": WORKING_DIR,
         "entity_summary_to_max_tokens": 500,
-        "cheap_model_func": deepseepk_model_if_cache,
-        "best_model_func": deepseepk_model_if_cache,
+        "cheap_model_func": lambda *args, **kwargs: deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs),
+        "best_model_func": lambda *args, **kwargs: deepseepk_model_if_cache(*args, system_prompt=system_prompt, **kwargs),
         "cheap_model_max_token_size": 4096,
         "best_model_max_token_size": 4096,
         "tiktoken_model_name": "gpt-4o",
@@ -176,11 +86,25 @@ async def nano_entity_extraction(text):


 if __name__ == "__main__":
-    lm = DeepSeek(model="deepseek-chat", api_key=os.environ["DEEPSEEK_API_KEY"])
+    system_prompt = """
+    You are a world-class AI system, capable of complex reasoning and reflection.
+    Reason through the query, and then provide your final response.
+    If you detect that you made a mistake in your reasoning at any point, correct yourself.
+    Think carefully.
+    """
+    lm = dspy.OpenAI(
+        model="deepseek-chat",
+        model_type="chat",
+        api_key=os.environ["DEEPSEEK_API_KEY"],
+        base_url=os.environ["DEEPSEEK_BASE_URL"],
+        system_prompt=system_prompt,
+        temperature=0.3,
+        top_p=1,
+        max_tokens=4096
+    )
     dspy.settings.configure(lm=lm)

     with open("./examples/data/test.txt", encoding="utf-8-sig") as f:
         text = f.read()

-    asyncio.run(nano_entity_extraction(text))
-    # extract_entities_dspy(text)
\ No newline at end of file
+    asyncio.run(nano_entity_extraction(text, system_prompt))
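
The two added lambdas above exist only to pin `system_prompt` onto every call of `deepseepk_model_if_cache`. A minimal sketch of the same binding with `functools.partial` (a stylistic alternative, not what the patch does; `demo` is a hypothetical stand-in for the model function):

import asyncio
from functools import partial

async def demo(prompt, system_prompt=None, **kwargs):
    # Hypothetical stand-in for deepseepk_model_if_cache.
    return f"system={system_prompt!r} prompt={prompt!r}"

bound = partial(demo, system_prompt="Think carefully.")
print(asyncio.run(bound("hello")))  # system_prompt is pre-applied on every call
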
prompt, model: str = "deepseek-chat", system_prompt=None, history_messages=[], **kwargs +) -> str: + openai_async_client = AsyncOpenAI( + api_key=os.environ.get("DEEPSEEK_API_KEY"), base_url="https://api.deepseek.com" + ) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Get the cached response if having------------------- + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + # ----------------------------------------------------- + + response = await openai_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + # Cache the response if having------------------- + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response.choices[0].message.content, "model": model}} + ) + # ----------------------------------------------------- + return response.choices[0].message.content + + + +def remove_if_exist(file): + if os.path.exists(file): + os.remove(file) + + +def insert(): + from time import time + + with open("./tests/mock_data.txt", encoding="utf-8-sig") as f: + # with open("./examples/data/test.txt", encoding="utf-8-sig") as f: + FAKE_TEXT = f.read() + + remove_if_exist(f"{WORKING_DIR}/vdb_entities.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json") + remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json") + remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml") + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=HNSWVectorStorage, + vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50}, + best_model_max_async=10, + cheap_model_max_async=10, + best_model_func=deepseepk_model_if_cache, + cheap_model_func=deepseepk_model_if_cache, + embedding_func=local_embedding, + ) + start = time() + rag.insert(FAKE_TEXT) + print("indexing time:", time() - start) + + +def query(): + rag = GraphRAG( + working_dir=WORKING_DIR, + enable_llm_cache=True, + vector_db_storage_cls=HNSWVectorStorage, + vector_db_storage_cls_kwargs={"max_elements": 1000000, "ef_search": 200, "M": 50}, + best_model_max_token_size=8196, + cheap_model_max_token_size=8196, + best_model_max_async=4, + cheap_model_max_async=4, + best_model_func=gpt_4o_mini_complete, + cheap_model_func=gpt_4o_mini_complete, + embedding_func=local_embedding, + + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="global") + ) + ) + print( + rag.query( + "What are the top themes in this story?", param=QueryParam(mode="local") + ) + ) + + +if __name__ == "__main__": + system_prompt = """ + You are a world-class AI system, capable of complex reasoning and reflection. + Reason through the query, and then provide your final response. + If you detect that you made a mistake in your reasoning at any point, correct yourself. + Think carefully. 
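
Both example files wrap the LLM call in the same cache-aside pattern: hash the (model, messages) pair, return the stored completion on a hit, write back after a miss. A self-contained sketch of that flow, with a plain dict standing in for `BaseKVStorage` and MD5-over-JSON as an assumed stand-in for `compute_args_hash`:

import hashlib
import json

async def cached_call(llm_call, cache: dict, model: str, messages: list) -> str:
    # Key on the model plus the full message list, as the example does.
    key = hashlib.md5(json.dumps([model, messages]).encode()).hexdigest()
    if key in cache:
        return cache[key]["return"]           # cache hit: skip the API call
    result = await llm_call(model, messages)  # cache miss: call the LLM
    cache[key] = {"return": result, "model": model}
    return result
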
+ """ + lm = dspy.OpenAI( + model="deepseek-chat", + model_type="chat", + api_key=os.environ["DEEPSEEK_API_KEY"], + base_url=os.environ["DEEPSEEK_BASE_URL"], + system_prompt=system_prompt, + temperature=0.3, + top_p=1, + max_tokens=4096 + ) + dspy.settings.configure(lm=lm) + insert() + query() diff --git a/examples/using_hnsw_as_vectorDB.py b/examples/using_hnsw_as_vectorDB.py index 41d68d7..8914da5 100644 --- a/examples/using_hnsw_as_vectorDB.py +++ b/examples/using_hnsw_as_vectorDB.py @@ -61,7 +61,6 @@ def insert(): from time import time with open("./tests/mock_data.txt", encoding="utf-8-sig") as f: - # with open("./examples/data/test.txt", encoding="utf-8-sig") as f: FAKE_TEXT = f.read() remove_if_exist(f"{WORKING_DIR}/vdb_entities.json") @@ -109,49 +108,7 @@ def query(): ) ) -import dspy -from openai import OpenAI - - -class DeepSeek(dspy.Module): - def __init__(self, model, api_key, **kwargs): - self.model = model - self.api_key = api_key - self.client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com") - self.provider = "default", - self.history = [] - self.kwargs = { - "temperature": 0.2, - "max_tokens": 2048, - **kwargs - } - - def basic_request(self, prompt, **kwargs): - response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} - ], - stream=False, - **self.kwargs - ) - self.history.append({"prompt": prompt, "response": response}) - return response - - def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): - response = self.basic_request(prompt, **kwargs) - completions = [choice.message.content for choice in response.choices] - return completions - - def inspect_history(self, n: int = 1): - if len(self.history) < n: - return self.history - return self.history[-n:] - if __name__ == "__main__": - lm = DeepSeek(model="deepseek-chat", api_key=os.environ["DEEPSEEK_API_KEY"]) - dspy.settings.configure(lm=lm) insert() query() diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index fd815c2..e659cda 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -12,7 +12,7 @@ @retry( - stop=stop_after_attempt(3), + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) @@ -69,7 +69,7 @@ async def gpt_4o_mini_complete( @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) @retry( - stop=stop_after_attempt(3), + stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), ) diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py index e9ee55b..2c2c28e 100644 --- a/nano_graphrag/_op.py +++ b/nano_graphrag/_op.py @@ -3,7 +3,6 @@ import re from typing import Union from collections import Counter, defaultdict -import dspy from ._utils import ( logger, @@ -15,7 +14,7 @@ list_of_list_to_csv, pack_user_ass_to_openai_messages, split_string_by_multi_markers, - truncate_list_by_token_size, + truncate_list_by_token_size ) from .base import ( BaseGraphStorage, @@ -217,195 +216,6 @@ async def _merge_edges_then_upsert( ) -class EntityTypeExtractionSignature(dspy.Signature): - input_text = dspy.InputField(desc="The text to extract entity types from") - entity_types = dspy.OutputField(desc="List of entity types present in the text separated by commas and make sure they are unique and important based on the text's 
diff --git a/nano_graphrag/_op.py b/nano_graphrag/_op.py
index e9ee55b..2c2c28e 100644
--- a/nano_graphrag/_op.py
+++ b/nano_graphrag/_op.py
@@ -3,7 +3,6 @@
 import re
 from typing import Union
 from collections import Counter, defaultdict
-import dspy

 from ._utils import (
     logger,
@@ -15,7 +14,7 @@
     list_of_list_to_csv,
     pack_user_ass_to_openai_messages,
     split_string_by_multi_markers,
-    truncate_list_by_token_size,
+    truncate_list_by_token_size
 )
 from .base import (
     BaseGraphStorage,
@@ -217,195 +216,6 @@ async def _merge_edges_then_upsert(
     )


-class EntityTypeExtractionSignature(dspy.Signature):
-    input_text = dspy.InputField(desc="The text to extract entity types from")
-    entity_types = dspy.OutputField(desc="List of entity types present in the text separated by commas and make sure they are unique and important based on the text's context, e.g. [person, event, technology, mission, organization, location]")
-
-
-class EntityExtractionSignature(dspy.Signature):
-    input_text = dspy.InputField(desc="The text to extract entities and relationships from")
-    entities = dspy.OutputField(desc="List of extracted entities with their types, descriptions, and importance scores, make sure descriptions are detailed and specific, and all entity types are included mentioned from the text")
-    relationships = dspy.OutputField(desc="List of relationships between entities, including detailed descriptions and importance scores")
-    reasoning = dspy.OutputField(desc="Step-by-step reasoning for entity, relationship, and event extraction, making sure all entities and relationships are mentioned from the text are considered")
-
-
-class EntityGleaningSignature(dspy.Signature):
-    context = dspy.InputField(desc="The current context including extracted entities and relationships")
-    entities = dspy.OutputField(desc="List of additional extracted entities with their types, descriptions, and importance scores")
-    relationships = dspy.OutputField(desc="List of additional relationships between entities, including detailed descriptions and importance scores")
-    continue_gleaning = dspy.OutputField(desc="Boolean indicating whether to continue gleaning or not")
-
-
-class EntityGleaner(dspy.Module):
-    def __init__(self, global_config):
-        super().__init__()
-        self.gleaner = dspy.ChainOfThought(EntityGleaningSignature)
-        self.global_config = global_config
-
-    def forward(self, context):
-        result = self.gleaner(context=context)
-        return result.entities, result.relationships, result.continue_gleaning
-
-
-class EntityExtractor(dspy.Module):
-    def __init__(self, global_config):
-        super().__init__()
-        self.type_extractor = dspy.TypedPredictor(EntityTypeExtractionSignature)
-        self.extractor = dspy.ChainOfThought(EntityExtractionSignature)
-        # self.gleaner = EntityGleaner(global_config)
-        self.global_config = global_config
-        self.prompt_template = PROMPTS["entity_extraction"]
-        self.context_base = dict(
-            tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
-            record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
-            completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
-        )
-
-    def forward(self, input_text: str, chunk_key: str):
-        type_result = self.type_extractor(input_text=input_text)
-        entity_types = type_result.entity_types.split(',')
-        formatted_prompt = self.prompt_template.format(
-            input_text=input_text,
-            entity_types=entity_types,
-            **self.context_base
-        )
-        extraction_result = self.extractor(input_text=formatted_prompt)
-        entities = self.handle_single_entity_extraction(extraction_result.entities, chunk_key)
-        relationships = self.handle_single_relationship_extraction(extraction_result.relationships, chunk_key)
-        # context = self.format_context(entities, relationships)
-
-        # for _ in range(self.global_config["entity_extract_max_gleaning"]):
-        #     additional_entities, additional_relationships, continue_gleaning = self.gleaner(context)
-        #     entities.extend(self.handle_single_entity_extraction(additional_entities, chunk_key))
-        #     relationships.extend(self.handle_single_relationship_extraction(additional_relationships, chunk_key))
-        #     context = self.format_context(entities, relationships)
-
-        #     if not continue_gleaning:
-        #         break
-
-        return entities, relationships
-
-    def format_context(self, entities, relationships):
-        entity_context = "\n".join([f"{e['entity_name']} ({e['entity_type']}): {e['description']}" for e in entities])
-        relationship_context = "\n".join([f"{r['src_id']} -> {r['tgt_id']}: {r['description']}" for r in relationships])
-        return f"Entities:\n{entity_context}\n\nRelationships:\n{relationship_context}"
-
-    def handle_single_entity_extraction(self, entities: str, chunk_key: str):
-        entities_list = entities.split('\n')
-        extracted_entities = []
-        for entity in entities_list:
-            match = re.match(r'\d+\.\s*\("entity"<\|>"([^"]+)"<\|>"([^"]+)"<\|>"([^"]+)"\)', entity)
-            if match:
-                entity_name = clean_str(match.group(1).upper())
-                entity_type = clean_str(match.group(2).upper())
-                entity_description = clean_str(match.group(3))
-                extracted_entities.append(dict(
-                    entity_name=entity_name,
-                    entity_type=entity_type,
-                    description=entity_description,
-                    source_id=chunk_key,
-                ))
-        return extracted_entities
-
-    def handle_single_relationship_extraction(self, relationships: str, chunk_key: str):
-        relationships_list = relationships.split('\n')
-        extracted_relationships = []
-        for relationship in relationships_list:
-            match = re.match(r'\d+\.\s*\("relationship"<\|>"([^"]+)"<\|>"([^"]+)"<\|>"([^"]+)"<\|>([0-9.]+)\)', relationship)
-            if match:
-                source = clean_str(match.group(1).upper())
-                target = clean_str(match.group(2).upper())
-                edge_description = clean_str(match.group(3))
-                importance_score = float(match.group(4)) if is_float_regex(match.group(4)) else 1.0
-                extracted_relationships.append(dict(
-                    src_id=source,
-                    tgt_id=target,
-                    weight=importance_score,
-                    description=edge_description,
-                    source_id=chunk_key,
-                ))
-        return extracted_relationships
-
-
-async def extract_entities_dspy(
-    chunks: dict[str, TextChunkSchema],
-    knwoledge_graph_inst: BaseGraphStorage,
-    entity_vdb: BaseVectorStorage,
-    global_config: dict,
-) -> Union[BaseGraphStorage, None]:
-    entity_extractor = EntityExtractor(global_config)
-    ordered_chunks = list(chunks.items())
-    already_processed = 0
-    already_entities = 0
-    already_relations = 0
-
-    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
-        nonlocal already_processed, already_entities, already_relations
-        chunk_key = chunk_key_dp[0]
-        chunk_dp = chunk_key_dp[1]
-        content = chunk_dp["content"]
-        entities, relationships = entity_extractor(input_text=content, chunk_key=chunk_key)
-
-        maybe_nodes = defaultdict(list)
-        maybe_edges = defaultdict(list)
-
-        for entity in entities:
-            maybe_nodes[entity["entity_name"]].append(entity)
-            already_entities += 1
-
-        for relationship in relationships:
-            maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(relationship)
-            already_relations += 1
-
-        already_processed += 1
-        now_ticks = PROMPTS["process_tickers"][
-            already_processed % len(PROMPTS["process_tickers"])
-        ]
-        print(
-            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
-            end="",
-            flush=True,
-        )
-        return dict(maybe_nodes), dict(maybe_edges)
-
-    results = await asyncio.gather(
-        *[_process_single_content(c) for c in ordered_chunks]
-    )
-    print()  # clear the progress bar
-    maybe_nodes = defaultdict(list)
-    maybe_edges = defaultdict(list)
-    for m_nodes, m_edges in results:
-        for k, v in m_nodes.items():
-            maybe_nodes[k].extend(v)
-        for k, v in m_edges.items():
-            maybe_edges[k].extend(v)
-    all_entities_data = await asyncio.gather(
-        *[
-            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
-            for k, v in maybe_nodes.items()
-        ]
-    )
-    await asyncio.gather(
-        *[
-            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
-            for k, v in maybe_edges.items()
-        ]
-    )
-    if not len(all_entities_data):
-        logger.warning("Didn't extract any entities, maybe your LLM is not working")
-        return None
-    if entity_vdb is not None:
-        data_for_vdb = {
-            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
-                "content": dp["entity_name"] + dp["description"],
-                "entity_name": dp["entity_name"],
-            }
-            for dp in all_entities_data
-        }
-        await entity_vdb.upsert(data_for_vdb)
-    return knwoledge_graph_inst
-
-
 async def extract_entities(
     chunks: dict[str, TextChunkSchema],
     knwoledge_graph_inst: BaseGraphStorage,
diff --git a/nano_graphrag/entity_extraction/__init__.py b/nano_graphrag/entity_extraction/__init__.py
new file mode 100644
index 0000000..e69de29
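
`extract.py` below drives the synchronous dspy module from async code via `asyncio.to_thread`, so each chunk's extraction runs in a worker thread instead of blocking the event loop while `asyncio.gather` fans out over chunks. The pattern in isolation (a trivial function stands in for `EntityRelationshipExtractor`):

import asyncio

def sync_extract(text: str) -> list[str]:
    return text.split()  # stand-in for the blocking dspy call

async def main() -> None:
    results = await asyncio.gather(
        *[asyncio.to_thread(sync_extract, chunk) for chunk in ("a b", "c d")]
    )
    print(results)  # [['a', 'b'], ['c', 'd']]

asyncio.run(main())
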
diff --git a/nano_graphrag/entity_extraction/extract.py b/nano_graphrag/entity_extraction/extract.py
new file mode 100644
index 0000000..d7ca129
--- /dev/null
+++ b/nano_graphrag/entity_extraction/extract.py
@@ -0,0 +1,94 @@
+import asyncio
+from collections import defaultdict
+
+from nano_graphrag._storage import BaseGraphStorage
+from nano_graphrag.base import (
+    BaseGraphStorage,
+    BaseVectorStorage,
+    TextChunkSchema,
+)
+from nano_graphrag.prompt import PROMPTS
+from nano_graphrag._utils import logger, compute_mdhash_id
+from nano_graphrag.entity_extraction.module import EntityRelationshipExtractor
+from nano_graphrag._op import _merge_edges_then_upsert, _merge_nodes_then_upsert
+
+
+async def extract_entities_dspy(
+    chunks: dict[str, TextChunkSchema],
+    knwoledge_graph_inst: BaseGraphStorage,
+    entity_vdb: BaseVectorStorage,
+    global_config: dict,
+) -> BaseGraphStorage | None:
+    entity_extractor = EntityRelationshipExtractor()
+    ordered_chunks = list(chunks.items())
+    already_processed = 0
+    already_entities = 0
+    already_relations = 0
+
+    async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
+        nonlocal already_processed, already_entities, already_relations
+        chunk_key = chunk_key_dp[0]
+        chunk_dp = chunk_key_dp[1]
+        content = chunk_dp["content"]
+        entities, relationships = await asyncio.to_thread(
+            entity_extractor, input_text=content, chunk_key=chunk_key
+        )
+
+        maybe_nodes = defaultdict(list)
+        maybe_edges = defaultdict(list)
+
+        for entity in entities:
+            maybe_nodes[entity["entity_name"]].append(entity)
+            already_entities += 1
+
+        for relationship in relationships:
+            maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append(relationship)
+            already_relations += 1
+
+        already_processed += 1
+        now_ticks = PROMPTS["process_tickers"][
+            already_processed % len(PROMPTS["process_tickers"])
+        ]
+        print(
+            f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
+            end="",
+            flush=True,
+        )
+        return dict(maybe_nodes), dict(maybe_edges)
+
+    results = await asyncio.gather(
+        *[_process_single_content(c) for c in ordered_chunks]
+    )
+    print()
+    maybe_nodes = defaultdict(list)
+    maybe_edges = defaultdict(list)
+    for m_nodes, m_edges in results:
+        for k, v in m_nodes.items():
+            maybe_nodes[k].extend(v)
+        for k, v in m_edges.items():
+            maybe_edges[k].extend(v)
+    all_entities_data = await asyncio.gather(
+        *[
+            _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
+            for k, v in maybe_nodes.items()
+        ]
+    )
+    await asyncio.gather(
+        *[
+            _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
+            for k, v in maybe_edges.items()
+        ]
+    )
+    if not len(all_entities_data):
+        logger.warning("Didn't extract any entities, maybe your LLM is not working")
+        return None
+    if entity_vdb is not None:
+        data_for_vdb = {
+            compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
+                "content": dp["entity_name"] + dp["description"],
+                "entity_name": dp["entity_name"],
+            }
+            for dp in all_entities_data
+        }
+        await entity_vdb.upsert(data_for_vdb)
+    return knwoledge_graph_inst
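
The vector-DB payload above keys each entity by `compute_mdhash_id(entity_name, prefix="ent-")`. A plausible equivalent of that helper (an assumption about its internals, not the library's actual code):

import hashlib

def mdhash_id(content: str, prefix: str = "") -> str:
    # Assumed: a prefixed MD5 digest of the content string.
    return prefix + hashlib.md5(content.encode()).hexdigest()

print(mdhash_id("SOME ENTITY", prefix="ent-"))  # "ent-" followed by 32 hex chars
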
float(relationship["importance_score"]), + }) + + return extracted_relationships diff --git a/nano_graphrag/entity_extraction/signature.py b/nano_graphrag/entity_extraction/signature.py new file mode 100644 index 0000000..3f72ccf --- /dev/null +++ b/nano_graphrag/entity_extraction/signature.py @@ -0,0 +1,75 @@ +import dspy + + +class EntityTypeExtraction(dspy.Signature): + """Signature for extracting entity types from input text.""" + + input_text = dspy.InputField(desc="The text to extract entity types from.") + entity_types = dspy.OutputField( + desc=""" + List of entity types present in the text separated by commas and make sure they are single word, unique, + and important based on the text's context. + For instance: [person, event, technology, mission, organization, location]. + """ + ) + + +class EntityExtraction(dspy.Signature): + """Signature for extracting entities from input text.""" + + input_text = dspy.InputField(desc="The text to extract entities from.") + entity_types = dspy.InputField(desc="List of entity types to consider.") + entities = dspy.OutputField( + desc=""" + List of extracted entities including their types, descriptions, and importance scores (0-1, with 1 being most important). + Format should be a list of dictionaries like the following: + [ + { + "name": "Entity name", + "type": "Entity type", + "description": "Detailed and specific description", + "importance_score": float (0.0 to 1.0) + } + ] + Make sure descriptions are detailed and specific, and all entity types are included mentioned from the text. + Ensure that all fields in the above format are present for every single entity dictionary. + Entities must have an importance score greater than 0.5, which means both primary and secondary importance entities will be extracted. + """ + ) + + +class RelationshipExtraction(dspy.Signature): + """Signature for extracting relationships between entities from input text.""" + + input_text = dspy.InputField(desc="The text to extract relationships from.") + entities = dspy.InputField( + desc=""" + List of extracted entities including their types, descriptions, and importance scores (0-1, with 1 being most important). + Format should be a list of dictionaries like the following: + [ + { + "name": "Entity name", + "type": "Entity type", + "description": "Detailed and specific description", + "importance_score": float (0.0 to 1.0) + } + ] + """ + ) + relationships = dspy.OutputField( + desc=""" + List of relationships between entities, including detailed descriptions and importance scores (0-1, with 1 being most important). + Format should be a list of dictionaries like the following: + [ + { + "source": "Source entity name", + "target": "Target entity name", + "description": "Detailed description of the relationship", + "importance_score": float (0.0 to 1.0) + } + ] + Make sure relationships are detailed and specific. + Ensure that all fields in the above format are present for every single relationship dictionary. 
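
The parsers in `module.py` above find candidate objects with `re.findall(r'\{[^}]+\}', ...)` and feed each straight into `json.loads`, which raises on the first malformed fragment and cannot match nested braces. A more tolerant sketch using `json.JSONDecoder.raw_decode` (an alternative hardening idea, not what this patch implements):

import json

def iter_json_objects(text: str):
    decoder = json.JSONDecoder()
    idx = 0
    while (start := text.find("{", idx)) != -1:
        try:
            obj, end = decoder.raw_decode(text, start)  # handles nested braces
            yield obj
            idx = end
        except json.JSONDecodeError:
            idx = start + 1  # skip the malformed brace and keep scanning

print(list(iter_json_objects('noise {"a": {"b": 1}} more {"c": 2}')))
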
+ """ + ) + diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py index 88aa678..3f663db 100644 --- a/nano_graphrag/graphrag.py +++ b/nano_graphrag/graphrag.py @@ -9,7 +9,7 @@ from ._llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding from ._op import ( chunking_by_token_size, - extract_entities_dspy, + extract_entities, generate_community_report, local_query, global_query, @@ -276,7 +276,7 @@ async def ainsert(self, string_or_strings): # ---------- extract/summary entity and upsert to graph logger.info("[Entity Extraction]...") - maybe_new_kg = await extract_entities_dspy( + maybe_new_kg = await extract_entities( inserting_chunks, knwoledge_graph_inst=self.chunk_entity_relation_graph, entity_vdb=self.entities_vdb,