diff --git a/.coveragerc b/.coveragerc index 2f7552f..177da7e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,4 +5,7 @@ exclude_lines = # Don't complain if tests don't hit defensive assertion code: raise NotImplementedError - logger. \ No newline at end of file + logger. +omit = + # Don't have a nice github action for neo4j now, so skip this file: + nano_graphrag/_storage/gdb_neo4j.py \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e893b30..e975b44 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,8 @@ jobs: run: | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - name: Build and Test + env: + NANO_GRAPHRAG_TEST_IGNORE_NEO4J: true run: | python -m pytest -o log_cli=true -o log_cli_level="INFO" --cov=nano_graphrag --cov-report=xml -v ./ - name: Check codecov file diff --git a/.gitignore b/.gitignore index c5b0583..801875a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,7 @@ # Created by https://www.toptal.com/developers/gitignore/api/python # Edit at https://www.toptal.com/developers/gitignore?templates=python test_cache.json -run_test.py -run_test_zh.py +run_test*.py nano_graphrag_cache*/ *.txt examples/benchmarks/fixtures/ diff --git a/docs/use_neo4j_for_graphrag.md b/docs/use_neo4j_for_graphrag.md new file mode 100644 index 0000000..3f92d26 --- /dev/null +++ b/docs/use_neo4j_for_graphrag.md @@ -0,0 +1,27 @@ +1. Install [Neo4j](https://neo4j.com/docs/operations-manual/current/installation/) +2. Install Neo4j GDS (graph data science) [plugin](https://neo4j.com/docs/graph-data-science/current/installation/neo4j-server/) +3. Start neo4j server +4. Get the `NEO4J_URL`, `NEO4J_USER` and `NEO4J_PASSWORD` + - By default, `NEO4J_URL` is `neo4j://localhost:7687` , `NEO4J_USER` is `neo4j` and `NEO4J_PASSWORD` is `neo4j` + +Pass your neo4j instance to `GraphRAG`: + +```python +from nano_graphrag import GraphRAG +from nano_graphrag._storage import Neo4jStorage + +neo4j_config = { + "neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"), + "neo4j_auth": ( + os.environ.get("NEO4J_USER", "neo4j"), + os.environ.get("NEO4J_PASSWORD", "neo4j"), + ) +} +GraphRAG( + graph_storage_cls=Neo4jStorage, + addon_params=neo4j_config, +) +``` + + + diff --git a/examples/no_openai_key_at_all.py b/examples/no_openai_key_at_all.py index f1624e0..1fce788 100644 --- a/examples/no_openai_key_at_all.py +++ b/examples/no_openai_key_at_all.py @@ -34,6 +34,7 @@ async def ollama_model_if_cache( ) -> str: # remove kwargs that are not supported by ollama kwargs.pop("max_tokens", None) + kwargs.pop("response_format", None) ollama_client = ollama.AsyncClient() messages = [] diff --git a/examples/using_ollama_as_llm.py b/examples/using_ollama_as_llm.py index a358b28..e067212 100644 --- a/examples/using_ollama_as_llm.py +++ b/examples/using_ollama_as_llm.py @@ -18,6 +18,7 @@ async def ollama_model_if_cache( ) -> str: # remove kwargs that are not supported by ollama kwargs.pop("max_tokens", None) + kwargs.pop("response_format", None) ollama_client = ollama.AsyncClient() messages = [] diff --git a/examples/using_ollama_as_llm_and_embedding.py b/examples/using_ollama_as_llm_and_embedding.py index 97614b9..44d669d 100644 --- a/examples/using_ollama_as_llm_and_embedding.py +++ b/examples/using_ollama_as_llm_and_embedding.py @@ -20,11 +20,13 @@ EMBEDDING_MODEL_DIM = 768 EMBEDDING_MODEL_MAX_TOKENS = 8192 + async def ollama_model_if_cache( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: # remove 
kwargs that are not supported by ollama kwargs.pop("max_tokens", None) + kwargs.pop("response_format", None) ollama_client = ollama.AsyncClient() messages = [] @@ -98,20 +100,21 @@ def insert(): # rag = GraphRAG(working_dir=WORKING_DIR, enable_llm_cache=True) # rag.insert(FAKE_TEXT[half_len:]) + # We're using Ollama to generate embeddings for the BGE model @wrap_embedding_func_with_attrs( - embedding_dim= EMBEDDING_MODEL_DIM, - max_token_size= EMBEDDING_MODEL_MAX_TOKENS, + embedding_dim=EMBEDDING_MODEL_DIM, + max_token_size=EMBEDDING_MODEL_MAX_TOKENS, ) - -async def ollama_embedding(texts :list[str]) -> np.ndarray: +async def ollama_embedding(texts: list[str]) -> np.ndarray: embed_text = [] for text in texts: - data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text) - embed_text.append(data["embedding"]) - + data = ollama.embeddings(model=EMBEDDING_MODEL, prompt=text) + embed_text.append(data["embedding"]) + return embed_text + if __name__ == "__main__": insert() query() diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index 068fa77..f658234 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -13,6 +13,23 @@ from ._utils import compute_args_hash, wrap_embedding_func_with_attrs from .base import BaseKVStorage +global_openai_async_client = None +global_azure_openai_async_client = None + + +def get_openai_async_client_instance(): + global global_openai_async_client + if global_openai_async_client is None: + global_openai_async_client = AsyncOpenAI() + return global_openai_async_client + + +def get_azure_openai_async_client_instance(): + global global_azure_openai_async_client + if global_azure_openai_async_client is None: + global_azure_openai_async_client = AsyncAzureOpenAI() + return global_azure_openai_async_client + @retry( stop=stop_after_attempt(5), @@ -22,7 +39,7 @@ async def openai_complete_if_cache( model, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - openai_async_client = AsyncOpenAI() + openai_async_client = get_openai_async_client_instance() hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: @@ -78,7 +95,7 @@ async def gpt_4o_mini_complete( retry=retry_if_exception_type((RateLimitError, APIConnectionError)), ) async def openai_embedding(texts: list[str]) -> np.ndarray: - openai_async_client = AsyncOpenAI() + openai_async_client = get_openai_async_client_instance() response = await openai_async_client.embeddings.create( model="text-embedding-3-small", input=texts, encoding_format="float" ) @@ -93,7 +110,7 @@ async def openai_embedding(texts: list[str]) -> np.ndarray: async def azure_openai_complete_if_cache( deployment_name, prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: - azure_openai_client = AsyncAzureOpenAI() + azure_openai_client = get_azure_openai_async_client_instance() hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) messages = [] if system_prompt: @@ -154,11 +171,7 @@ async def azure_gpt_4o_mini_complete( retry=retry_if_exception_type((RateLimitError, APIConnectionError)), ) async def azure_openai_embedding(texts: list[str]) -> np.ndarray: - azure_openai_client = AsyncAzureOpenAI( - api_key=os.environ.get("API_KEY_EMB"), - api_version=os.environ.get("API_VERSION_EMB"), - azure_endpoint=os.environ.get("AZURE_ENDPOINT_EMB"), - ) + azure_openai_client = get_azure_openai_async_client_instance() response = await azure_openai_client.embeddings.create( model="text-embedding-3-small", input=texts, encoding_format="float" ) diff --git 
a/nano_graphrag/_storage/__init__.py b/nano_graphrag/_storage/__init__.py new file mode 100644 index 0000000..c8184ab --- /dev/null +++ b/nano_graphrag/_storage/__init__.py @@ -0,0 +1,5 @@ +from .gdb_networkx import NetworkXStorage +from .gdb_neo4j import Neo4jStorage +from .vdb_hnswlib import HNSWVectorStorage +from .vdb_nanovectordb import NanoVectorDBStorage +from .kv_json import JsonKVStorage diff --git a/nano_graphrag/_storage/gdb_neo4j.py b/nano_graphrag/_storage/gdb_neo4j.py new file mode 100644 index 0000000..e45634a --- /dev/null +++ b/nano_graphrag/_storage/gdb_neo4j.py @@ -0,0 +1,330 @@ +import json +import asyncio +from collections import defaultdict +from neo4j import AsyncGraphDatabase +from dataclasses import dataclass +from typing import Union +from ..base import BaseGraphStorage, SingleCommunitySchema +from .._utils import logger +from ..prompt import GRAPH_FIELD_SEP + +neo4j_lock = asyncio.Lock() + + +def make_path_idable(path): + return path.replace(".", "_").replace("/", "__").replace("-", "_") + + +@dataclass +class Neo4jStorage(BaseGraphStorage): + def __post_init__(self): + self.neo4j_url = self.global_config["addon_params"].get("neo4j_url", None) + self.neo4j_auth = self.global_config["addon_params"].get("neo4j_auth", None) + self.namespace = ( + f"{make_path_idable(self.global_config['working_dir'])}__{self.namespace}" + ) + logger.info(f"Using the label {self.namespace} for Neo4j as identifier") + if self.neo4j_url is None or self.neo4j_auth is None: + raise ValueError("Missing neo4j_url or neo4j_auth in addon_params") + self.async_driver = AsyncGraphDatabase.driver( + self.neo4j_url, auth=self.neo4j_auth + ) + + # async def create_database(self): + # async with self.async_driver.session() as session: + # try: + # constraints = await session.run("SHOW CONSTRAINTS") + # # TODO I don't know why CREATE CONSTRAINT IF NOT EXISTS still trigger error + # # so have to check if the constrain exists + # constrain_exists = False + + # async for record in constraints: + # if ( + # self.namespace in record["labelsOrTypes"] + # and "id" in record["properties"] + # and record["type"] == "UNIQUENESS" + # ): + # constrain_exists = True + # break + # if not constrain_exists: + # await session.run( + # f"CREATE CONSTRAINT FOR (n:{self.namespace}) REQUIRE n.id IS UNIQUE" + # ) + # logger.info(f"Add constraint for namespace: {self.namespace}") + + # except Exception as e: + # logger.error(f"Error accessing or setting up the database: {str(e)}") + # raise + + async def _init_workspace(self): + await self.async_driver.verify_authentication() + await self.async_driver.verify_connectivity() + # TODOLater: create database if not exists always cause an error when async + # await self.create_database() + + async def index_start_callback(self): + logger.info("Init Neo4j workspace") + await self._init_workspace() + + async def has_node(self, node_id: str) -> bool: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (n:{self.namespace}) WHERE n.id = $node_id RETURN COUNT(n) > 0 AS exists", + node_id=node_id, + ) + record = await result.single() + return record["exists"] if record else False + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) " + "WHERE s.id = $source_id AND t.id = $target_id " + "RETURN COUNT(r) > 0 AS exists", + source_id=source_node_id, + target_id=target_node_id, + ) + 
record = await result.single() + return record["exists"] if record else False + + async def node_degree(self, node_id: str) -> int: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (n:{self.namespace}) WHERE n.id = $node_id " + f"RETURN COUNT {{(n)-[]-(:{self.namespace})}} AS degree", + node_id=node_id, + ) + record = await result.single() + return record["degree"] if record else 0 + + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (s:{self.namespace}), (t:{self.namespace}) " + "WHERE s.id = $src_id AND t.id = $tgt_id " + f"RETURN COUNT {{(s)-[]-(:{self.namespace})}} + COUNT {{(t)-[]-(:{self.namespace})}} AS degree", + src_id=src_id, + tgt_id=tgt_id, + ) + record = await result.single() + return record["degree"] if record else 0 + + async def get_node(self, node_id: str) -> Union[dict, None]: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (n:{self.namespace}) WHERE n.id = $node_id RETURN properties(n) AS node_data", + node_id=node_id, + ) + record = await result.single() + raw_node_data = record["node_data"] if record else None + if raw_node_data is None: + return None + raw_node_data["clusters"] = json.dumps( + [ + { + "level": index, + "cluster": cluster_id, + } + for index, cluster_id in enumerate( + raw_node_data.get("communityIds", []) + ) + ] + ) + return raw_node_data + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> Union[dict, None]: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) " + "WHERE s.id = $source_id AND t.id = $target_id " + "RETURN properties(r) AS edge_data", + source_id=source_node_id, + target_id=target_node_id, + ) + record = await result.single() + return record["edge_data"] if record else None + + async def get_node_edges( + self, source_node_id: str + ) -> Union[list[tuple[str, str]], None]: + async with self.async_driver.session() as session: + result = await session.run( + f"MATCH (s:{self.namespace})-[r]->(t:{self.namespace}) WHERE s.id = $source_id " + "RETURN s.id AS source, t.id AS target", + source_id=source_node_id, + ) + edges = [] + async for record in result: + edges.append((record["source"], record["target"])) + return edges + + async def upsert_node(self, node_id: str, node_data: dict[str, str]): + node_type = node_data.get("entity_type", "UNKNOWN").strip('"') + async with self.async_driver.session() as session: + await session.run( + f"MERGE (n:{self.namespace}:{node_type} {{id: $node_id}}) " + "SET n += $node_data", + node_id=node_id, + node_data=node_data, + ) + + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + edge_data.setdefault("weight", 0.0) + async with self.async_driver.session() as session: + await session.run( + f"MATCH (s:{self.namespace}), (t:{self.namespace}) " + "WHERE s.id = $source_id AND t.id = $target_id " + "MERGE (s)-[r:RELATED]->(t) " # Added relationship type 'RELATED' + "SET r += $edge_data", + source_id=source_node_id, + target_id=target_node_id, + edge_data=edge_data, + ) + + async def clustering(self, algorithm: str): + if algorithm != "leiden": + raise ValueError( + f"Clustering algorithm {algorithm} not supported in Neo4j implementation" + ) + + random_seed = self.global_config["graph_cluster_seed"] + max_level = 
self.global_config["max_graph_cluster_size"] + async with self.async_driver.session() as session: + try: + # Project the graph with undirected relationships + await session.run( + f""" + CALL gds.graph.project( + 'graph_{self.namespace}', + ['{self.namespace}'], + {{ + RELATED: {{ + orientation: 'UNDIRECTED', + properties: ['weight'] + }} + }} + ) + """ + ) + + # Run Leiden algorithm + result = await session.run( + f""" + CALL gds.leiden.write( + 'graph_{self.namespace}', + {{ + writeProperty: 'communityIds', + includeIntermediateCommunities: True, + relationshipWeightProperty: "weight", + maxLevels: {max_level}, + tolerance: 0.0001, + gamma: 1.0, + theta: 0.01, + randomSeed: {random_seed} + }} + ) + YIELD communityCount, modularities; + """ + ) + result = await result.single() + community_count: int = result["communityCount"] + modularities = result["modularities"] + logger.info( + f"Performed graph clustering with {community_count} communities and modularities {modularities}" + ) + finally: + # Drop the projected graph + await session.run(f"CALL gds.graph.drop('graph_{self.namespace}')") + + async def community_schema(self) -> dict[str, SingleCommunitySchema]: + results = defaultdict( + lambda: dict( + level=None, + title=None, + edges=set(), + nodes=set(), + chunk_ids=set(), + occurrence=0.0, + sub_communities=[], + ) + ) + + async with self.async_driver.session() as session: + # Fetch community data + result = await session.run( + f""" + MATCH (n:{self.namespace}) + WITH n, n.communityIds AS communityIds, [(n)-[]-(m:{self.namespace}) | m.id] AS connected_nodes + RETURN n.id AS node_id, n.source_id AS source_id, + communityIds AS cluster_key, + connected_nodes + """ + ) + + # records = await result.fetch() + + max_num_ids = 0 + async for record in result: + for index, c_id in enumerate(record["cluster_key"]): + node_id = str(record["node_id"]) + source_id = record["source_id"] + level = index + cluster_key = str(c_id) + connected_nodes = record["connected_nodes"] + + results[cluster_key]["level"] = level + results[cluster_key]["title"] = f"Cluster {cluster_key}" + results[cluster_key]["nodes"].add(node_id) + results[cluster_key]["edges"].update( + [ + tuple(sorted([node_id, str(connected)])) + for connected in connected_nodes + if connected != node_id + ] + ) + chunk_ids = source_id.split(GRAPH_FIELD_SEP) + results[cluster_key]["chunk_ids"].update(chunk_ids) + max_num_ids = max( + max_num_ids, len(results[cluster_key]["chunk_ids"]) + ) + + # Process results + for k, v in results.items(): + v["edges"] = [list(e) for e in v["edges"]] + v["nodes"] = list(v["nodes"]) + v["chunk_ids"] = list(v["chunk_ids"]) + v["occurrence"] = len(v["chunk_ids"]) / max_num_ids + + # Compute sub-communities (this is a simplified approach) + for cluster in results.values(): + cluster["sub_communities"] = [ + sub_key + for sub_key, sub_cluster in results.items() + if sub_cluster["level"] > cluster["level"] + and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"])) + ] + + return dict(results) + + async def index_done_callback(self): + await self.async_driver.close() + + async def _debug_delete_all_node_edges(self): + async with self.async_driver.session() as session: + try: + # Delete all relationships in the namespace + await session.run(f"MATCH (n:{self.namespace})-[r]-() DELETE r") + + # Delete all nodes in the namespace + await session.run(f"MATCH (n:{self.namespace}) DELETE n") + + logger.info( + f"All nodes and edges in namespace '{self.namespace}' have been deleted." 
+ ) + except Exception as e: + logger.error(f"Error deleting nodes and edges: {str(e)}") + raise diff --git a/nano_graphrag/_storage.py b/nano_graphrag/_storage/gdb_networkx.py similarity index 52% rename from nano_graphrag/_storage.py rename to nano_graphrag/_storage/gdb_networkx.py index 831ca26..e29bf3e 100644 --- a/nano_graphrag/_storage.py +++ b/nano_graphrag/_storage/gdb_networkx.py @@ -1,254 +1,18 @@ -import asyncio import html import json import os from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Union, cast -import pickle -import hnswlib import networkx as nx import numpy as np -from nano_vectordb import NanoVectorDB -import xxhash -from ._utils import load_json, logger, write_json -from .base import ( +from .._utils import logger +from ..base import ( BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, SingleCommunitySchema, ) -from .prompt import GRAPH_FIELD_SEP - - -@dataclass -class JsonKVStorage(BaseKVStorage): - def __post_init__(self): - working_dir = self.global_config["working_dir"] - self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") - self._data = load_json(self._file_name) or {} - logger.info(f"Load KV {self.namespace} with {len(self._data)} data") - - async def all_keys(self) -> list[str]: - return list(self._data.keys()) - - async def index_done_callback(self): - write_json(self._data, self._file_name) - - async def get_by_id(self, id): - return self._data.get(id, None) - - async def get_by_ids(self, ids, fields=None): - if fields is None: - return [self._data.get(id, None) for id in ids] - return [ - ( - {k: v for k, v in self._data[id].items() if k in fields} - if self._data.get(id, None) - else None - ) - for id in ids - ] - - async def filter_keys(self, data: list[str]) -> set[str]: - return set([s for s in data if s not in self._data]) - - async def upsert(self, data: dict[str, dict]): - self._data.update(data) - - async def drop(self): - self._data = {} - - -@dataclass -class NanoVectorDBStorage(BaseVectorStorage): - cosine_better_than_threshold: float = 0.2 - - def __post_init__(self): - - self._client_file_name = os.path.join( - self.global_config["working_dir"], f"vdb_{self.namespace}.json" - ) - self._max_batch_size = self.global_config["embedding_batch_num"] - self._client = NanoVectorDB( - self.embedding_func.embedding_dim, storage_file=self._client_file_name - ) - self.cosine_better_than_threshold = self.global_config.get( - "query_better_than_threshold", self.cosine_better_than_threshold - ) - - async def upsert(self, data: dict[str, dict]): - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") - if not len(data): - logger.warning("You insert an empty data to vector DB") - return [] - list_data = [ - { - "__id__": k, - **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - embeddings_list = await asyncio.gather( - *[self.embedding_func(batch) for batch in batches] - ) - embeddings = np.concatenate(embeddings_list) - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - results = self._client.upsert(datas=list_data) - return results - - async def query(self, query: str, top_k=5): - embedding = await self.embedding_func([query]) - embedding = embedding[0] - results = self._client.query( - 
query=embedding, - top_k=top_k, - better_than_threshold=self.cosine_better_than_threshold, - ) - results = [ - {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results - ] - return results - - async def index_done_callback(self): - self._client.save() - - -@dataclass -class HNSWVectorStorage(BaseVectorStorage): - ef_construction: int = 100 - M: int = 16 - max_elements: int = 1000000 - ef_search: int = 50 - num_threads: int = -1 - _index: Any = field(init=False) - _metadata: dict[str, dict] = field(default_factory=dict) - _current_elements: int = 0 - - def __post_init__(self): - self._index_file_name = os.path.join( - self.global_config["working_dir"], f"{self.namespace}_hnsw.index" - ) - self._metadata_file_name = os.path.join( - self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl" - ) - self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100) - - hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {}) - self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction) - self.M = hnsw_params.get("M", self.M) - self.max_elements = hnsw_params.get("max_elements", self.max_elements) - self.ef_search = hnsw_params.get("ef_search", self.ef_search) - self.num_threads = hnsw_params.get("num_threads", self.num_threads) - self._index = hnswlib.Index( - space="cosine", dim=self.embedding_func.embedding_dim - ) - - if os.path.exists(self._index_file_name) and os.path.exists( - self._metadata_file_name - ): - self._index.load_index( - self._index_file_name, max_elements=self.max_elements - ) - with open(self._metadata_file_name, "rb") as f: - self._metadata, self._current_elements = pickle.load(f) - logger.info( - f"Loaded existing index for {self.namespace} with {self._current_elements} elements" - ) - else: - self._index.init_index( - max_elements=self.max_elements, - ef_construction=self.ef_construction, - M=self.M, - ) - self._index.set_ef(self.ef_search) - self._metadata = {} - self._current_elements = 0 - logger.info(f"Created new index for {self.namespace}") - - async def upsert(self, data: dict[str, dict]) -> np.ndarray: - logger.info(f"Inserting {len(data)} vectors to {self.namespace}") - if not data: - logger.warning("You insert an empty data to vector DB") - return [] - - if self._current_elements + len(data) > self.max_elements: - raise ValueError( - f"Cannot insert {len(data)} elements. 
Current: {self._current_elements}, Max: {self.max_elements}" - ) - - list_data = [ - { - "id": k, - **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batch_size = min(self._embedding_batch_num, len(contents)) - embeddings = np.concatenate( - await asyncio.gather( - *[ - self.embedding_func(contents[i : i + batch_size]) - for i in range(0, len(contents), batch_size) - ] - ) - ) - - ids = np.fromiter( - (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data), - dtype=np.uint32, - count=len(list_data), - ) - self._metadata.update( - { - id_int: { - k: v for k, v in d.items() if k in self.meta_fields or k == "id" - } - for id_int, d in zip(ids, list_data) - } - ) - self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads) - self._current_elements = self._index.get_current_count() - return ids - - async def query(self, query: str, top_k: int = 5) -> list[dict]: - if self._current_elements == 0: - return [] - - top_k = min(top_k, self._current_elements) - - if top_k > self.ef_search: - logger.warning( - f"Setting ef_search to {top_k} because top_k is larger than ef_search" - ) - self._index.set_ef(top_k) - - embedding = await self.embedding_func([query]) - labels, distances = self._index.knn_query( - data=embedding[0], k=top_k, num_threads=self.num_threads - ) - - return [ - { - **self._metadata.get(label, {}), - "distance": distance, - "similarity": 1 - distance, - } - for label, distance in zip(labels[0], distances[0]) - ] - - async def index_done_callback(self): - self._index.save_index(self._index_file_name) - with open(self._metadata_file_name, "wb") as f: - pickle.dump((self._metadata, self._current_elements), f) +from ..prompt import GRAPH_FIELD_SEP @dataclass diff --git a/nano_graphrag/_storage/kv_json.py b/nano_graphrag/_storage/kv_json.py new file mode 100644 index 0000000..b802f26 --- /dev/null +++ b/nano_graphrag/_storage/kv_json.py @@ -0,0 +1,46 @@ +import os +from dataclasses import dataclass + +from .._utils import load_json, logger, write_json +from ..base import ( + BaseKVStorage, +) + + +@dataclass +class JsonKVStorage(BaseKVStorage): + def __post_init__(self): + working_dir = self.global_config["working_dir"] + self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json") + self._data = load_json(self._file_name) or {} + logger.info(f"Load KV {self.namespace} with {len(self._data)} data") + + async def all_keys(self) -> list[str]: + return list(self._data.keys()) + + async def index_done_callback(self): + write_json(self._data, self._file_name) + + async def get_by_id(self, id): + return self._data.get(id, None) + + async def get_by_ids(self, ids, fields=None): + if fields is None: + return [self._data.get(id, None) for id in ids] + return [ + ( + {k: v for k, v in self._data[id].items() if k in fields} + if self._data.get(id, None) + else None + ) + for id in ids + ] + + async def filter_keys(self, data: list[str]) -> set[str]: + return set([s for s in data if s not in self._data]) + + async def upsert(self, data: dict[str, dict]): + self._data.update(data) + + async def drop(self): + self._data = {} diff --git a/nano_graphrag/_storage/vdb_hnswlib.py b/nano_graphrag/_storage/vdb_hnswlib.py new file mode 100644 index 0000000..3e98c95 --- /dev/null +++ b/nano_graphrag/_storage/vdb_hnswlib.py @@ -0,0 +1,141 @@ +import asyncio +import os +from dataclasses import dataclass, field +from typing import Any +import pickle +import hnswlib 
+import numpy as np +import xxhash + +from .._utils import logger +from ..base import BaseVectorStorage + + +@dataclass +class HNSWVectorStorage(BaseVectorStorage): + ef_construction: int = 100 + M: int = 16 + max_elements: int = 1000000 + ef_search: int = 50 + num_threads: int = -1 + _index: Any = field(init=False) + _metadata: dict[str, dict] = field(default_factory=dict) + _current_elements: int = 0 + + def __post_init__(self): + self._index_file_name = os.path.join( + self.global_config["working_dir"], f"{self.namespace}_hnsw.index" + ) + self._metadata_file_name = os.path.join( + self.global_config["working_dir"], f"{self.namespace}_hnsw_metadata.pkl" + ) + self._embedding_batch_num = self.global_config.get("embedding_batch_num", 100) + + hnsw_params = self.global_config.get("vector_db_storage_cls_kwargs", {}) + self.ef_construction = hnsw_params.get("ef_construction", self.ef_construction) + self.M = hnsw_params.get("M", self.M) + self.max_elements = hnsw_params.get("max_elements", self.max_elements) + self.ef_search = hnsw_params.get("ef_search", self.ef_search) + self.num_threads = hnsw_params.get("num_threads", self.num_threads) + self._index = hnswlib.Index( + space="cosine", dim=self.embedding_func.embedding_dim + ) + + if os.path.exists(self._index_file_name) and os.path.exists( + self._metadata_file_name + ): + self._index.load_index( + self._index_file_name, max_elements=self.max_elements + ) + with open(self._metadata_file_name, "rb") as f: + self._metadata, self._current_elements = pickle.load(f) + logger.info( + f"Loaded existing index for {self.namespace} with {self._current_elements} elements" + ) + else: + self._index.init_index( + max_elements=self.max_elements, + ef_construction=self.ef_construction, + M=self.M, + ) + self._index.set_ef(self.ef_search) + self._metadata = {} + self._current_elements = 0 + logger.info(f"Created new index for {self.namespace}") + + async def upsert(self, data: dict[str, dict]) -> np.ndarray: + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not data: + logger.warning("You insert an empty data to vector DB") + return [] + + if self._current_elements + len(data) > self.max_elements: + raise ValueError( + f"Cannot insert {len(data)} elements. 
Current: {self._current_elements}, Max: {self.max_elements}" + ) + + list_data = [ + { + "id": k, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batch_size = min(self._embedding_batch_num, len(contents)) + embeddings = np.concatenate( + await asyncio.gather( + *[ + self.embedding_func(contents[i : i + batch_size]) + for i in range(0, len(contents), batch_size) + ] + ) + ) + + ids = np.fromiter( + (xxhash.xxh32_intdigest(d["id"].encode()) for d in list_data), + dtype=np.uint32, + count=len(list_data), + ) + self._metadata.update( + { + id_int: { + k: v for k, v in d.items() if k in self.meta_fields or k == "id" + } + for id_int, d in zip(ids, list_data) + } + ) + self._index.add_items(data=embeddings, ids=ids, num_threads=self.num_threads) + self._current_elements = self._index.get_current_count() + return ids + + async def query(self, query: str, top_k: int = 5) -> list[dict]: + if self._current_elements == 0: + return [] + + top_k = min(top_k, self._current_elements) + + if top_k > self.ef_search: + logger.warning( + f"Setting ef_search to {top_k} because top_k is larger than ef_search" + ) + self._index.set_ef(top_k) + + embedding = await self.embedding_func([query]) + labels, distances = self._index.knn_query( + data=embedding[0], k=top_k, num_threads=self.num_threads + ) + + return [ + { + **self._metadata.get(label, {}), + "distance": distance, + "similarity": 1 - distance, + } + for label, distance in zip(labels[0], distances[0]) + ] + + async def index_done_callback(self): + self._index.save_index(self._index_file_name) + with open(self._metadata_file_name, "wb") as f: + pickle.dump((self._metadata, self._current_elements), f) diff --git a/nano_graphrag/_storage/vdb_nanovectordb.py b/nano_graphrag/_storage/vdb_nanovectordb.py new file mode 100644 index 0000000..f73ab06 --- /dev/null +++ b/nano_graphrag/_storage/vdb_nanovectordb.py @@ -0,0 +1,68 @@ +import asyncio +import os +from dataclasses import dataclass +import numpy as np +from nano_vectordb import NanoVectorDB + +from .._utils import logger +from ..base import BaseVectorStorage + + +@dataclass +class NanoVectorDBStorage(BaseVectorStorage): + cosine_better_than_threshold: float = 0.2 + + def __post_init__(self): + + self._client_file_name = os.path.join( + self.global_config["working_dir"], f"vdb_{self.namespace}.json" + ) + self._max_batch_size = self.global_config["embedding_batch_num"] + self._client = NanoVectorDB( + self.embedding_func.embedding_dim, storage_file=self._client_file_name + ) + self.cosine_better_than_threshold = self.global_config.get( + "query_better_than_threshold", self.cosine_better_than_threshold + ) + + async def upsert(self, data: dict[str, dict]): + logger.info(f"Inserting {len(data)} vectors to {self.namespace}") + if not len(data): + logger.warning("You insert an empty data to vector DB") + return [] + list_data = [ + { + "__id__": k, + **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + embeddings_list = await asyncio.gather( + *[self.embedding_func(batch) for batch in batches] + ) + embeddings = np.concatenate(embeddings_list) + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + results = self._client.upsert(datas=list_data) + return results + + async def 
query(self, query: str, top_k=5): + embedding = await self.embedding_func([query]) + embedding = embedding[0] + results = self._client.query( + query=embedding, + top_k=top_k, + better_than_threshold=self.cosine_better_than_threshold, + ) + results = [ + {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results + ] + return results + + async def index_done_callback(self): + self._client.save() diff --git a/nano_graphrag/_utils.py b/nano_graphrag/_utils.py index 567394d..5185959 100644 --- a/nano_graphrag/_utils.py +++ b/nano_graphrag/_utils.py @@ -17,6 +17,18 @@ ENCODER = None +def always_get_an_event_loop() -> asyncio.AbstractEventLoop: + try: + # If there is already an event loop, use it. + loop = asyncio.get_event_loop() + except RuntimeError: + # If in a sub-thread, create a new event loop. + logger.info("Creating a new event loop in a sub-thread.") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + def locate_json_string_body_from_string(content: str) -> Union[str, None]: """Locate the JSON string body from a string""" maybe_json_str = re.search(r"{.*}", content, re.DOTALL) diff --git a/nano_graphrag/base.py b/nano_graphrag/base.py index ea9bb21..96fe225 100644 --- a/nano_graphrag/base.py +++ b/nano_graphrag/base.py @@ -61,6 +61,10 @@ class StorageNameSpace: namespace: str global_config: dict + async def index_start_callback(self): + """commit the storage operations after indexing""" + pass + async def index_done_callback(self): """commit the storage operations after indexing""" pass diff --git a/nano_graphrag/entity_extraction/extract.py b/nano_graphrag/entity_extraction/extract.py index d19c433..e72fa1c 100644 --- a/nano_graphrag/entity_extraction/extract.py +++ b/nano_graphrag/entity_extraction/extract.py @@ -4,7 +4,6 @@ from openai import BadRequestError from collections import defaultdict import dspy -from nano_graphrag._storage import BaseGraphStorage from nano_graphrag.base import ( BaseGraphStorage, BaseVectorStorage, @@ -20,34 +19,32 @@ async def generate_dataset( chunks: dict[str, TextChunkSchema], filepath: str, save_dataset: bool = True, - global_config: dict = {} + global_config: dict = {}, ) -> list[dspy.Example]: entity_extractor = TypedEntityRelationshipExtractor() if global_config.get("use_compiled_dspy_entity_relationship", False): entity_extractor.load(global_config["entity_relationship_module_path"]) - + ordered_chunks = list(chunks.items()) already_processed = 0 already_entities = 0 already_relations = 0 - async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]) -> dspy.Example: + async def _process_single_content( + chunk_key_dp: tuple[str, TextChunkSchema] + ) -> dspy.Example: nonlocal already_processed, already_entities, already_relations chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] try: - prediction = await asyncio.to_thread( - entity_extractor, input_text=content - ) + prediction = await asyncio.to_thread(entity_extractor, input_text=content) entities, relationships = prediction.entities, prediction.relationships except BadRequestError as e: logger.error(f"Error in TypedEntityRelationshipExtractor: {e}") entities, relationships = [], [] example = dspy.Example( - input_text=content, - entities=entities, - relationships=relationships + input_text=content, entities=entities, relationships=relationships ).with_inputs("input_text") already_entities += len(entities) already_relations += len(relationships) @@ -65,12 +62,18 @@ async def _process_single_content(chunk_key_dp: tuple[str, 
TextChunkSchema]) -> examples = await asyncio.gather( *[_process_single_content(c) for c in ordered_chunks] ) - filtered_examples = [example for example in examples if len(example.entities) > 0 and len(example.relationships) > 0] + filtered_examples = [ + example + for example in examples + if len(example.entities) > 0 and len(example.relationships) > 0 + ] num_filtered_examples = len(examples) - len(filtered_examples) if save_dataset: - with open(filepath, 'wb') as f: + with open(filepath, "wb") as f: pickle.dump(filtered_examples, f) - logger.info(f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples") + logger.info( + f"Saved {len(filtered_examples)} examples with keys: {filtered_examples[0].keys()}, filtered {num_filtered_examples} examples" + ) return filtered_examples @@ -85,7 +88,7 @@ async def extract_entities_dspy( if global_config.get("use_compiled_dspy_entity_relationship", False): entity_extractor.load(global_config["entity_relationship_module_path"]) - + ordered_chunks = list(chunks.items()) already_processed = 0 already_entities = 0 @@ -97,25 +100,25 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] try: - prediction = await asyncio.to_thread( - entity_extractor, input_text=content - ) + prediction = await asyncio.to_thread(entity_extractor, input_text=content) entities, relationships = prediction.entities, prediction.relationships except BadRequestError as e: logger.error(f"Error in TypedEntityRelationshipExtractor: {e}") entities, relationships = [], [] - + maybe_nodes = defaultdict(list) maybe_edges = defaultdict(list) - + for entity in entities: entity["source_id"] = chunk_key - maybe_nodes[entity['entity_name']].append(entity) + maybe_nodes[entity["entity_name"]].append(entity) already_entities += 1 for relationship in relationships: relationship["source_id"] = chunk_key - maybe_edges[(relationship['src_id'], relationship['tgt_id'])].append(relationship) + maybe_edges[(relationship["src_id"], relationship["tgt_id"])].append( + relationship + ) already_relations += 1 already_processed += 1 diff --git a/nano_graphrag/entity_extraction/metric.py b/nano_graphrag/entity_extraction/metric.py index 0b41413..24df4f0 100644 --- a/nano_graphrag/entity_extraction/metric.py +++ b/nano_graphrag/entity_extraction/metric.py @@ -21,23 +21,42 @@ class AssessRelationships(dspy.Signature): - Balance the impact of matched and unmatched relationships in the final score. """ - gold_relationships: list[Relationship] = dspy.InputField(desc="The gold-standard relationships to compare against.") - predicted_relationships: list[Relationship] = dspy.InputField(desc="The predicted relationships to compare against the gold-standard relationships.") - similarity_score: float = dspy.OutputField(desc="Similarity score between 0 and 1, with 1 being the highest similarity.") - - -def relationships_similarity_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float: + gold_relationships: list[Relationship] = dspy.InputField( + desc="The gold-standard relationships to compare against." + ) + predicted_relationships: list[Relationship] = dspy.InputField( + desc="The predicted relationships to compare against the gold-standard relationships." + ) + similarity_score: float = dspy.OutputField( + desc="Similarity score between 0 and 1, with 1 being the highest similarity." 
+ ) + + +def relationships_similarity_metric( + gold: dspy.Example, pred: dspy.Prediction, trace=None +) -> float: model = dspy.TypedChainOfThought(AssessRelationships) - gold_relationships = [Relationship(**item) for item in gold['relationships']] - predicted_relationships = [Relationship(**item) for item in pred['relationships']] - similarity_score = float(model(gold_relationships=gold_relationships, predicted_relationships=predicted_relationships).similarity_score) + gold_relationships = [Relationship(**item) for item in gold["relationships"]] + predicted_relationships = [Relationship(**item) for item in pred["relationships"]] + similarity_score = float( + model( + gold_relationships=gold_relationships, + predicted_relationships=predicted_relationships, + ).similarity_score + ) return similarity_score -def entity_recall_metric(gold: dspy.Example, pred: dspy.Prediction, trace=None) -> float: - true_set = set(item['entity_name'] for item in gold['entities']) - pred_set = set(item['entity_name'] for item in pred['entities']) +def entity_recall_metric( + gold: dspy.Example, pred: dspy.Prediction, trace=None +) -> float: + true_set = set(item["entity_name"] for item in gold["entities"]) + pred_set = set(item["entity_name"] for item in pred["entities"]) true_positives = len(pred_set.intersection(true_set)) false_negatives = len(true_set - pred_set) - recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 + recall = ( + true_positives / (true_positives + false_negatives) + if (true_positives + false_negatives) > 0 + else 0 + ) return recall diff --git a/nano_graphrag/entity_extraction/module.py b/nano_graphrag/entity_extraction/module.py index 70e86db..1ccce7f 100644 --- a/nano_graphrag/entity_extraction/module.py +++ b/nano_graphrag/entity_extraction/module.py @@ -9,42 +9,101 @@ https://github.com/SciPhi-AI/R2R/blob/6e958d1e451c1cb10b6fc868572659785d1091cb/r2r/providers/prompts/defaults.jsonl """ ENTITY_TYPES = [ - "PERSON", "ORGANIZATION", "LOCATION", "DATE", "TIME", "MONEY", - "PERCENTAGE", "PRODUCT", "EVENT", "LANGUAGE", "NATIONALITY", - "RELIGION", "TITLE", "PROFESSION", "ANIMAL", "PLANT", "DISEASE", - "MEDICATION", "CHEMICAL", "MATERIAL", "COLOR", "SHAPE", - "MEASUREMENT", "WEATHER", "NATURAL_DISASTER", "AWARD", "LAW", - "CRIME", "TECHNOLOGY", "SOFTWARE", "HARDWARE", "VEHICLE", - "FOOD", "DRINK", "SPORT", "MUSIC_GENRE", "INSTRUMENT", - "ARTWORK", "BOOK", "MOVIE", "TV_SHOW", "ACADEMIC_SUBJECT", - "SCIENTIFIC_THEORY", "POLITICAL_PARTY", "CURRENCY", - "STOCK_SYMBOL", "FILE_TYPE", "PROGRAMMING_LANGUAGE", - "MEDICAL_PROCEDURE", "CELESTIAL_BODY" + "PERSON", + "ORGANIZATION", + "LOCATION", + "DATE", + "TIME", + "MONEY", + "PERCENTAGE", + "PRODUCT", + "EVENT", + "LANGUAGE", + "NATIONALITY", + "RELIGION", + "TITLE", + "PROFESSION", + "ANIMAL", + "PLANT", + "DISEASE", + "MEDICATION", + "CHEMICAL", + "MATERIAL", + "COLOR", + "SHAPE", + "MEASUREMENT", + "WEATHER", + "NATURAL_DISASTER", + "AWARD", + "LAW", + "CRIME", + "TECHNOLOGY", + "SOFTWARE", + "HARDWARE", + "VEHICLE", + "FOOD", + "DRINK", + "SPORT", + "MUSIC_GENRE", + "INSTRUMENT", + "ARTWORK", + "BOOK", + "MOVIE", + "TV_SHOW", + "ACADEMIC_SUBJECT", + "SCIENTIFIC_THEORY", + "POLITICAL_PARTY", + "CURRENCY", + "STOCK_SYMBOL", + "FILE_TYPE", + "PROGRAMMING_LANGUAGE", + "MEDICAL_PROCEDURE", + "CELESTIAL_BODY", ] class Entity(BaseModel): entity_name: str = Field(..., description="The name of the entity.") entity_type: str = Field(..., description="The type of the entity.") - description: str = 
Field(..., description="The description of the entity, in details and comprehensive.") - importance_score: float = Field(..., ge=0, le=1, description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.") + description: str = Field( + ..., description="The description of the entity, in details and comprehensive." + ) + importance_score: float = Field( + ..., + ge=0, + le=1, + description="Importance score of the entity. Should be between 0 and 1 with 1 being the most important.", + ) class Relationship(BaseModel): src_id: str = Field(..., description="The name of the source entity.") tgt_id: str = Field(..., description="The name of the target entity.") - description: str = Field(..., description="The description of the relationship between the source and target entity, in details and comprehensive.") - weight: float = Field(..., ge=0, le=1, description="The weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.") - order: int = Field(..., ge=1, le=3, description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.") + description: str = Field( + ..., + description="The description of the relationship between the source and target entity, in details and comprehensive.", + ) + weight: float = Field( + ..., + ge=0, + le=1, + description="The weight of the relationship. Should be between 0 and 1 with 1 being the strongest relationship.", + ) + order: int = Field( + ..., + ge=1, + le=3, + description="The order of the relationship. 1 for direct relationships, 2 for second-order, 3 for third-order.", + ) class CombinedExtraction(dspy.Signature): """ - Given a text document that is potentially relevant to this activity and a list of entity types, + Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. - + Entity Guidelines: - 1. Each entity name should be an actual atomic word from the input text. + 1. Each entity name should be an actual atomic word from the input text. 2. Avoid duplicates and generic terms. 3. Make sure descriptions are detailed and comprehensive. Use multiple complete sentences for each point below: a). The entity's role or significance in the context @@ -52,9 +111,9 @@ class CombinedExtraction(dspy.Signature): c). Relationships to other entities (if applicable) d). Historical or cultural relevance (if applicable) e). Any notable actions or events associated with the entity - 4. All entity types from the text must be included. + 4. All entity types from the text must be included. 5. IMPORTANT: Only use entity types from the provided 'entity_types' list. Do not introduce new entity types. - + Relationship Guidelines: 1. Make sure relationship descriptions are detailed and comprehensive. Use multiple complete sentences for each point below: a). The nature of the relationship (e.g., familial, professional, causal) @@ -69,13 +128,23 @@ class CombinedExtraction(dspy.Signature): 3. The "src_id" and "tgt_id" fields must exactly match entity names from the extracted entities list. 
""" - input_text: str = dspy.InputField(desc="The text to extract entities and relationships from.") - entity_types: list[str] = dspy.InputField(desc="List of entity types used for extraction.") - entities_relationships: list[Union[Entity, Relationship]] = dspy.OutputField(desc="List of entities and relationships extracted from the text.") + input_text: str = dspy.InputField( + desc="The text to extract entities and relationships from." + ) + entity_types: list[str] = dspy.InputField( + desc="List of entity types used for extraction." + ) + entities_relationships: list[Union[Entity, Relationship]] = dspy.OutputField( + desc="List of entities and relationships extracted from the text." + ) class TypedEntityRelationshipExtractorException(dspy.Module): - def __init__(self, predictor: dspy.Module, exception_types: tuple[type[Exception]] = (Exception,)): + def __init__( + self, + predictor: dspy.Module, + exception_types: tuple[type[Exception]] = (Exception,), + ): super().__init__() self.predictor = predictor self.exception_types = exception_types @@ -96,29 +165,37 @@ def forward(self, **kwargs): class TypedEntityRelationshipExtractor(dspy.Module): - def __init__(self, lm: dspy.LM = None, reasoning: dspy.OutputField = None, max_retries: int = 3): + def __init__( + self, + lm: dspy.LM = None, + reasoning: dspy.OutputField = None, + max_retries: int = 3, + ): super().__init__() self.lm = lm self.entity_types = ENTITY_TYPES self.extractor = dspy.TypedChainOfThought( - signature=CombinedExtraction, - reasoning=reasoning, - max_retries=max_retries + signature=CombinedExtraction, reasoning=reasoning, max_retries=max_retries + ) + self.extractor = TypedEntityRelationshipExtractorException( + self.extractor, exception_types=(ValueError,) ) - self.extractor = TypedEntityRelationshipExtractorException(self.extractor, exception_types=(ValueError, )) def forward(self, input_text: str) -> dspy.Prediction: with dspy.context(lm=self.lm if self.lm is not None else dspy.settings.lm): - extraction_result = self.extractor(input_text=input_text, entity_types=self.entity_types) + extraction_result = self.extractor( + input_text=input_text, entity_types=self.entity_types + ) entities = [ dict( entity_name=clean_str(entity.entity_name.upper()), entity_type=clean_str(entity.entity_type.upper()), description=clean_str(entity.description), - importance_score=float(entity.importance_score) + importance_score=float(entity.importance_score), ) - for entity in extraction_result.entities_relationships if isinstance(entity, Entity) + for entity in extraction_result.entities_relationships + if isinstance(entity, Entity) ] relationships = [ @@ -127,9 +204,10 @@ def forward(self, input_text: str) -> dspy.Prediction: tgt_id=clean_str(relationship.tgt_id.upper()), description=clean_str(relationship.description), weight=float(relationship.weight), - order=int(relationship.order) + order=int(relationship.order), ) - for relationship in extraction_result.entities_relationships if isinstance(relationship, Relationship) + for relationship in extraction_result.entities_relationships + if isinstance(relationship, Relationship) ] - + return dspy.Prediction(entities=entities, relationships=relationships) diff --git a/nano_graphrag/graphrag.py b/nano_graphrag/graphrag.py index ce13270..ead9380 100644 --- a/nano_graphrag/graphrag.py +++ b/nano_graphrag/graphrag.py @@ -35,6 +35,7 @@ compute_mdhash_id, limit_async_func_call, convert_response_to_json, + always_get_an_event_loop, logger, ) from .base import ( @@ -46,18 +47,6 @@ ) -def 
always_get_an_event_loop() -> asyncio.AbstractEventLoop: - try: - # If there is already an event loop, use it. - loop = asyncio.get_event_loop() - except RuntimeError: - # If in a sub-thread, create a new event loop. - logger.info("Creating a new event loop in a sub-thread.") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - @dataclass class GraphRAG: working_dir: str = field( @@ -224,7 +213,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): if param.mode == "local" and not self.enable_local: raise ValueError("enable_local is False, cannot query in local mode") if param.mode == "naive" and not self.enable_naive_rag: - raise ValueError("enable_naive_rag is False, cannot query in local mode") + raise ValueError("enable_naive_rag is False, cannot query in naive mode") if param.mode == "local": response = await local_query( query, @@ -259,6 +248,7 @@ async def aquery(self, query: str, param: QueryParam = QueryParam()): return response async def ainsert(self, string_or_strings): + await self._insert_start() try: if isinstance(string_or_strings, str): string_or_strings = [string_or_strings] @@ -327,6 +317,16 @@ async def ainsert(self, string_or_strings): finally: await self._insert_done() + async def _insert_start(self): + tasks = [] + for storage_inst in [ + self.chunk_entity_relation_graph, + ]: + if storage_inst is None: + continue + tasks.append(cast(StorageNameSpace, storage_inst).index_start_callback()) + await asyncio.gather(*tasks) + async def _insert_done(self): tasks = [] for storage_inst in [ diff --git a/readme.md b/readme.md index 6a2550a..5e166da 100644 --- a/readme.md +++ b/readme.md @@ -42,7 +42,7 @@ 🎁 Excluding `tests` and prompts, `nano-graphrag` is about **800 lines of code**. -👌 Small yet [**portable**](#Components), [**asynchronous**](#Async) and fully typed. +👌 Small yet [**portable**](#Components)(faiss, neo4j, ollama...), [**asynchronous**](#Async) and fully typed. @@ -164,20 +164,22 @@ await graph_func.aquery(...) 
Below are the components you can use: -| Type | What | Where | -| :-------------- | :----------------------------------------------------------: | :------------------------------: | -| LLM | OpenAI | Built-in | -| | DeepSeek | [examples](./examples) | -| | `ollama` | [examples](./examples) | -| Embedding | OpenAI | Built-in | -| | Sentence-transformers | [examples](./examples) | -| Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) | Built-in | -| | [`hnswlib`](https://github.com/nmslib/hnswlib) | Built-in, [examples](./examples) | -| | [`milvus-lite`](https://github.com/milvus-io/milvus-lite) | [examples](./examples) | -| | [faiss](https://github.com/facebookresearch/faiss?tab=readme-ov-file) | [examples](./examples) | -| Visualization | graphml | [examples](./examples) | -| Chunking | by token size | Built-in | -| | by text splitter | Built-in | +| Type | What | Where | +| :-------------- | :----------------------------------------------------------: | :-----------------------------------------------: | +| LLM | OpenAI | Built-in | +| | DeepSeek | [examples](./examples) | +| | `ollama` | [examples](./examples) | +| Embedding | OpenAI | Built-in | +| | Sentence-transformers | [examples](./examples) | +| Vector DataBase | [`nano-vectordb`](https://github.com/gusye1234/nano-vectordb) | Built-in | +| | [`hnswlib`](https://github.com/nmslib/hnswlib) | Built-in, [examples](./examples) | +| | [`milvus-lite`](https://github.com/milvus-io/milvus-lite) | [examples](./examples) | +| | [faiss](https://github.com/facebookresearch/faiss?tab=readme-ov-file) | [examples](./examples) | +| Graph Storage | [`networkx`](https://networkx.org/documentation/stable/index.html) | Built-in | +| | [`neo4j`](https://neo4j.com/) | Built-in([doc](./docs/use_neo4j_for_graphrag.md)) | +| Visualization | graphml | [examples](./examples) | +| Chunking | by token size | Built-in | +| | by text splitter | Built-in | - `Built-in` means we have that implementation inside `nano-graphrag`. `examples` means we have that implementation inside an tutorial under [examples](./examples) folder. 
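For the new Graph Storage rows above, here is a minimal end-to-end sketch of pointing `GraphRAG` at the Neo4j backend, following the usage shown in `docs/use_neo4j_for_graphrag.md`; the `import os`, the working directory, and the sample insert/query strings are assumptions added only to keep the script self-contained:

```python
import os

from nano_graphrag import GraphRAG
from nano_graphrag._storage import Neo4jStorage

# Connection settings; the defaults mirror docs/use_neo4j_for_graphrag.md.
neo4j_config = {
    "neo4j_url": os.environ.get("NEO4J_URL", "neo4j://localhost:7687"),
    "neo4j_auth": (
        os.environ.get("NEO4J_USER", "neo4j"),
        os.environ.get("NEO4J_PASSWORD", "neo4j"),
    ),
}

graph_func = GraphRAG(
    working_dir="./neo4j_graphrag_cache",  # assumed cache directory
    graph_storage_cls=Neo4jStorage,
    addon_params=neo4j_config,
)

# Index some text into the Neo4j-backed graph, then query it.
graph_func.insert("Neo4j stores the entity-relation graph for nano-graphrag.")
print(graph_func.query("What backs the entity-relation graph?"))
```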
diff --git a/requirements-dev.txt b/requirements-dev.txt index 28e1f07..b732211 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,4 @@ pytest future pytest-asyncio pytest-cov +python-dotenv \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 78ec190..be0e993 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ hnswlib xxhash tenacity dspy-ai +neo4j \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index ca0e42e..7464562 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,5 @@ import logging +import dotenv +dotenv.load_dotenv() logging.basicConfig(level=logging.INFO) diff --git a/tests/test_neo4j_storage.py b/tests/test_neo4j_storage.py new file mode 100644 index 0000000..8a87c3a --- /dev/null +++ b/tests/test_neo4j_storage.py @@ -0,0 +1,210 @@ +import os +import pytest +import numpy as np +from functools import wraps +from nano_graphrag import GraphRAG +from nano_graphrag._storage import Neo4jStorage +from nano_graphrag._utils import wrap_embedding_func_with_attrs + +if os.environ.get("NANO_GRAPHRAG_TEST_IGNORE_NEO4J", False): + pytest.skip("skipping neo4j tests", allow_module_level=True) + + +@pytest.fixture(scope="module") +def neo4j_config(): + return { + "neo4j_url": os.environ.get("NEO4J_URL", "bolt://localhost:7687"), + "neo4j_auth": ( + os.environ.get("NEO4J_USER", "neo4j"), + os.environ.get("NEO4J_PASSWORD", "neo4j"), + ), + } + + +@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=8192) +async def mock_embedding(texts: list[str]) -> np.ndarray: + return np.random.rand(len(texts), 384) + + +@pytest.fixture +def neo4j_storage(neo4j_config): + rag = GraphRAG( + working_dir="./tests/neo4j_test", + embedding_func=mock_embedding, + graph_storage_cls=Neo4jStorage, + addon_params=neo4j_config, + ) + storage = rag.chunk_entity_relation_graph + return storage + + +def reset_graph(func): + @wraps(func) + async def new_func(neo4j_storage): + await neo4j_storage._debug_delete_all_node_edges() + await neo4j_storage.index_start_callback() + results = await func(neo4j_storage) + await neo4j_storage._debug_delete_all_node_edges() + return results + + return new_func + + +def test_neo4j_storage_init(): + rag = GraphRAG( + working_dir="./tests/neo4j_test", + embedding_func=mock_embedding, + ) + with pytest.raises(ValueError): + storage = Neo4jStorage( + namespace="nanographrag_test", global_config=rag.__dict__ + ) + + +@pytest.mark.asyncio +@reset_graph +async def test_upsert_and_get_node(neo4j_storage): + node_id = "node1" + node_data = {"attr1": "value1", "attr2": "value2"} + return_data = {"id": node_id, "clusters": "[]", **node_data} + + await neo4j_storage.upsert_node(node_id, node_data) + + result = await neo4j_storage.get_node(node_id) + assert result == return_data + + has_node = await neo4j_storage.has_node(node_id) + assert has_node is True + + +@pytest.mark.asyncio +@reset_graph +async def test_upsert_and_get_edge(neo4j_storage): + source_id = "node1" + target_id = "node2" + edge_data = {"weight": 1.0, "type": "connection"} + + await neo4j_storage.upsert_node(source_id, {}) + await neo4j_storage.upsert_node(target_id, {}) + await neo4j_storage.upsert_edge(source_id, target_id, edge_data) + + result = await neo4j_storage.get_edge(source_id, target_id) + print(result) + assert result == edge_data + + has_edge = await neo4j_storage.has_edge(source_id, target_id) + assert has_edge is True + + +@pytest.mark.asyncio +@reset_graph +async def 
test_node_degree(neo4j_storage): + node_id = "center" + await neo4j_storage.upsert_node(node_id, {}) + + num_neighbors = 5 + for i in range(num_neighbors): + neighbor_id = f"neighbor{i}" + await neo4j_storage.upsert_node(neighbor_id, {}) + await neo4j_storage.upsert_edge(node_id, neighbor_id, {}) + + degree = await neo4j_storage.node_degree(node_id) + assert degree == num_neighbors + + +@pytest.mark.asyncio +@reset_graph +async def test_edge_degree(neo4j_storage): + source_id = "node1" + target_id = "node2" + + await neo4j_storage.upsert_node(source_id, {}) + await neo4j_storage.upsert_node(target_id, {}) + await neo4j_storage.upsert_edge(source_id, target_id, {}) + + num_source_neighbors = 3 + for i in range(num_source_neighbors): + neighbor_id = f"neighbor{i}" + await neo4j_storage.upsert_node(neighbor_id, {}) + await neo4j_storage.upsert_edge(source_id, neighbor_id, {}) + + num_target_neighbors = 2 + for i in range(num_target_neighbors): + neighbor_id = f"target_neighbor{i}" + await neo4j_storage.upsert_node(neighbor_id, {}) + await neo4j_storage.upsert_edge(target_id, neighbor_id, {}) + + expected_edge_degree = (num_source_neighbors + 1) + (num_target_neighbors + 1) + edge_degree = await neo4j_storage.edge_degree(source_id, target_id) + assert edge_degree == expected_edge_degree + + +@pytest.mark.asyncio +@reset_graph +async def test_get_node_edges(neo4j_storage): + center_id = "center" + await neo4j_storage.upsert_node(center_id, {}) + + expected_edges = [] + for i in range(3): + neighbor_id = f"neighbor{i}" + await neo4j_storage.upsert_node(neighbor_id, {}) + await neo4j_storage.upsert_edge(center_id, neighbor_id, {}) + expected_edges.append((center_id, neighbor_id)) + + result = await neo4j_storage.get_node_edges(center_id) + print(result) + assert set(result) == set(expected_edges) + + +@pytest.mark.asyncio +@reset_graph +async def test_leiden_clustering(neo4j_storage): + for i in range(10): + await neo4j_storage.upsert_node(f"NODE{i}", {"source_id": f"chunk{i}"}) + + for i in range(9): + await neo4j_storage.upsert_edge(f"NODE{i}", f"NODE{i+1}", {"weight": 1.0}) + + await neo4j_storage.clustering(algorithm="leiden") + + community_schema = await neo4j_storage.community_schema() + + assert len(community_schema) > 0 + + for community in community_schema.values(): + assert "level" in community + assert "title" in community + assert "edges" in community + assert "nodes" in community + assert "chunk_ids" in community + assert "occurrence" in community + assert "sub_communities" in community + print(community) + + +@pytest.mark.asyncio +@reset_graph +async def test_nonexistent_node_and_edge(neo4j_storage): + assert await neo4j_storage.has_node("nonexistent") is False + assert await neo4j_storage.has_edge("node1", "node2") is False + assert await neo4j_storage.get_node("nonexistent") is None + assert await neo4j_storage.get_edge("node1", "node2") is None + assert await neo4j_storage.get_node_edges("nonexistent") == [] + assert await neo4j_storage.node_degree("nonexistent") == 0 + assert await neo4j_storage.edge_degree("node1", "node2") == 0 + + +@pytest.mark.asyncio +@reset_graph +async def test_cluster_error_handling(neo4j_storage): + with pytest.raises( + ValueError, match="Clustering algorithm invalid_algo not supported" + ): + await neo4j_storage.clustering("invalid_algo") + + +@pytest.mark.asyncio +@reset_graph +async def test_index_done(neo4j_storage): + await neo4j_storage.index_done_callback() diff --git a/tests/test_openai.py b/tests/test_openai.py index afc5eea..9751aee 100644 
--- a/tests/test_openai.py +++ b/tests/test_openai.py @@ -4,9 +4,23 @@ from nano_graphrag import _llm +def test_get_openai_async_client_instance(): + with patch("nano_graphrag._llm.AsyncOpenAI") as mock_openai: + mock_openai.return_value = "CLIENT" + client = _llm.get_openai_async_client_instance() + assert client == "CLIENT" + + +def test_get_azure_openai_async_client_instance(): + with patch("nano_graphrag._llm.AsyncAzureOpenAI") as mock_openai: + mock_openai.return_value = "AZURE_CLIENT" + client = _llm.get_azure_openai_async_client_instance() + assert client == "AZURE_CLIENT" + + @pytest.fixture def mock_openai_client(): - with patch("nano_graphrag._llm.AsyncOpenAI") as mock_openai: + with patch("nano_graphrag._llm.get_openai_async_client_instance") as mock_openai: mock_client = AsyncMock() mock_openai.return_value = mock_client yield mock_client @@ -14,7 +28,9 @@ def mock_openai_client(): @pytest.fixture def mock_azure_openai_client(): - with patch("nano_graphrag._llm.AsyncAzureOpenAI") as mock_openai: + with patch( + "nano_graphrag._llm.get_azure_openai_async_client_instance" + ) as mock_openai: mock_client = AsyncMock() mock_openai.return_value = mock_client yield mock_client @@ -37,7 +53,7 @@ async def test_openai_gpt4o(mock_openai_client): @pytest.mark.asyncio -async def test_openai_gpt4o_mini(mock_openai_client): +async def test_openai_gpt4omini(mock_openai_client): mock_response = AsyncMock() mock_response.choices = [Mock(message=Mock(content="1"))] messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}] @@ -69,7 +85,7 @@ async def test_azure_openai_gpt4o(mock_azure_openai_client): @pytest.mark.asyncio -async def test_azure_openai_gpt4o_mini(mock_azure_openai_client): +async def test_azure_openai_gpt4omini(mock_azure_openai_client): mock_response = AsyncMock() mock_response.choices = [Mock(message=Mock(content="1"))] messages = [{"role": "system", "content": "3"}, {"role": "user", "content": "2"}]
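The `_llm.py` hunk above replaces per-call `AsyncOpenAI()` / `AsyncAzureOpenAI()` construction with lazily created module-level singletons, and the updated tests patch the getter functions rather than the client classes. A small sketch of how caller code and a test interact with that getter; the `chat_once` helper is hypothetical (not part of this diff) and is included purely for illustration:

```python
from unittest.mock import AsyncMock, Mock, patch

import pytest

from nano_graphrag import _llm


async def chat_once(prompt: str) -> str:
    # Hypothetical caller: anything routed through the shared getter can be
    # tested by patching the getter instead of the OpenAI client class.
    client = _llm.get_openai_async_client_instance()
    response = await client.chat.completions.create(
        model="gpt-4o-mini", messages=[{"role": "user", "content": prompt}]
    )
    return response.choices[0].message.content


@pytest.mark.asyncio
async def test_chat_once_with_patched_getter():
    with patch("nano_graphrag._llm.get_openai_async_client_instance") as mock_getter:
        mock_client = AsyncMock()
        # Mirror the response shape used by the tests above.
        mock_client.chat.completions.create.return_value = Mock(
            choices=[Mock(message=Mock(content="1"))]
        )
        mock_getter.return_value = mock_client

        assert await chat_once("hello") == "1"
        mock_client.chat.completions.create.assert_awaited_once()
```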