diff --git a/.semversioner/next-release/patch-20241212190223784600.json b/.semversioner/next-release/patch-20241212190223784600.json new file mode 100644 index 0000000000..d54d621cf7 --- /dev/null +++ b/.semversioner/next-release/patch-20241212190223784600.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Streamline flows." +} diff --git a/.semversioner/next-release/patch-20241213181544864279.json b/.semversioner/next-release/patch-20241213181544864279.json new file mode 100644 index 0000000000..17361b585a --- /dev/null +++ b/.semversioner/next-release/patch-20241213181544864279.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Move extractor code to co-locate with operations." +} diff --git a/graphrag/config/models/claim_extraction_config.py b/graphrag/config/models/claim_extraction_config.py index fcecd905d4..2192777400 100644 --- a/graphrag/config/models/claim_extraction_config.py +++ b/graphrag/config/models/claim_extraction_config.py @@ -37,12 +37,7 @@ class ClaimExtractionConfig(LLMConfig): def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict: """Get the resolved claim extraction strategy.""" - from graphrag.index.operations.extract_covariates import ( - ExtractClaimsStrategyType, - ) - return self.strategy or { - "type": ExtractClaimsStrategyType.graph_intelligence, "llm": self.llm.model_dump(), **self.parallelization.model_dump(), "extraction_prompt": (Path(root_dir) / self.prompt) diff --git a/graphrag/config/models/embed_graph_config.py b/graphrag/config/models/embed_graph_config.py index 12dd90cf4e..c4597e03ca 100644 --- a/graphrag/config/models/embed_graph_config.py +++ b/graphrag/config/models/embed_graph_config.py @@ -36,7 +36,7 @@ class EmbedGraphConfig(BaseModel): def resolved_strategy(self) -> dict: """Get the resolved node2vec strategy.""" - from graphrag.index.operations.embed_graph import ( + from graphrag.index.operations.embed_graph.typing import ( EmbedGraphStrategyType, ) diff --git a/graphrag/index/flows/compute_communities.py b/graphrag/index/flows/compute_communities.py index 09ec084ac6..6ca74ded4a 100644 --- a/graphrag/index/flows/compute_communities.py +++ b/graphrag/index/flows/compute_communities.py @@ -9,15 +9,11 @@ from graphrag.index.operations.cluster_graph import cluster_graph from graphrag.index.operations.create_graph import create_graph -from graphrag.index.operations.snapshot import snapshot -from graphrag.storage.pipeline_storage import PipelineStorage -async def compute_communities( +def compute_communities( base_relationship_edges: pd.DataFrame, - storage: PipelineStorage, clustering_strategy: dict[str, Any], - snapshot_transient_enabled: bool = False, ) -> pd.DataFrame: """All the steps to create the base entity graph.""" graph = create_graph(base_relationship_edges) @@ -32,12 +28,4 @@ async def compute_communities( ).explode("title") base_communities["community"] = base_communities["community"].astype(int) - if snapshot_transient_enabled: - await snapshot( - base_communities, - name="base_communities", - storage=storage, - formats=["parquet"], - ) - return base_communities diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py index f699d507b7..3204425d11 100644 --- a/graphrag/index/flows/create_base_text_units.py +++ b/graphrag/index/flows/create_base_text_units.py @@ -15,18 +15,14 @@ ) from graphrag.index.operations.chunk_text import chunk_text -from graphrag.index.operations.snapshot import snapshot from graphrag.index.utils.hashing import 
gen_sha512_hash -from graphrag.storage.pipeline_storage import PipelineStorage -async def create_base_text_units( +def create_base_text_units( documents: pd.DataFrame, callbacks: VerbCallbacks, - storage: PipelineStorage, chunk_by_columns: list[str], chunk_strategy: dict[str, Any] | None = None, - snapshot_transient_enabled: bool = False, ) -> pd.DataFrame: """All the steps to transform base text_units.""" sort = documents.sort_values(by=["id"], ascending=[True]) @@ -74,19 +70,7 @@ async def create_base_text_units( # rename for downstream consumption chunked.rename(columns={"chunk": "text"}, inplace=True) - output = cast( - "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True) - ) - - if snapshot_transient_enabled: - await snapshot( - output, - name="create_base_text_units", - storage=storage, - formats=["parquet"], - ) - - return output + return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)) # TODO: would be nice to inline this completely in the main method with pandas diff --git a/graphrag/index/flows/create_final_community_reports.py b/graphrag/index/flows/create_final_community_reports.py index 496ddd2778..574945de9d 100644 --- a/graphrag/index/flows/create_final_community_reports.py +++ b/graphrag/index/flows/create_final_community_reports.py @@ -12,7 +12,12 @@ ) from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.graph.extractors.community_reports.schemas import ( +from graphrag.index.operations.summarize_communities import ( + prepare_community_reports, + restore_community_hierarchy, + summarize_communities, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import ( CLAIM_DESCRIPTION, CLAIM_DETAILS, CLAIM_ID, @@ -32,11 +37,6 @@ NODE_ID, NODE_NAME, ) -from graphrag.index.operations.summarize_communities import ( - prepare_community_reports, - restore_community_hierarchy, - summarize_communities, -) async def create_final_community_reports( diff --git a/graphrag/index/flows/create_final_covariates.py b/graphrag/index/flows/create_final_covariates.py index ba400be56f..f9b5f7e377 100644 --- a/graphrag/index/flows/create_final_covariates.py +++ b/graphrag/index/flows/create_final_covariates.py @@ -13,7 +13,7 @@ ) from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.operations.extract_covariates import ( +from graphrag.index.operations.extract_covariates.extract_covariates import ( extract_covariates, ) diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py index 1748294e36..0b6932e405 100644 --- a/graphrag/index/flows/create_final_nodes.py +++ b/graphrag/index/flows/create_final_nodes.py @@ -10,9 +10,10 @@ VerbCallbacks, ) +from graphrag.index.operations.compute_degree import compute_degree from graphrag.index.operations.create_graph import create_graph -from graphrag.index.operations.embed_graph import embed_graph -from graphrag.index.operations.layout_graph import layout_graph +from graphrag.index.operations.embed_graph.embed_graph import embed_graph +from graphrag.index.operations.layout_graph.layout_graph import layout_graph def create_final_nodes( @@ -37,15 +38,19 @@ def create_final_nodes( layout_strategy, embeddings=graph_embeddings, ) - nodes = base_entity_nodes.merge( - layout, left_on="title", right_on="label", how="left" - ) - joined = nodes.merge(base_communities, on="title", how="left") - joined["level"] = joined["level"].fillna(0).astype(int) - joined["community"] = 
joined["community"].fillna(-1).astype(int) + degrees = compute_degree(graph) - return joined.loc[ + nodes = ( + base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left") + .merge(degrees, on="title", how="left") + .merge(base_communities, on="title", how="left") + ) + nodes["level"] = nodes["level"].fillna(0).astype(int) + nodes["community"] = nodes["community"].fillna(-1).astype(int) + # disconnected nodes and those with no community even at level 0 can be missing degree + nodes["degree"] = nodes["degree"].fillna(0).astype(int) + return nodes.loc[ :, [ "id", diff --git a/graphrag/index/flows/create_final_relationships.py b/graphrag/index/flows/create_final_relationships.py index 5f58d1ffbe..03f6e362ea 100644 --- a/graphrag/index/flows/create_final_relationships.py +++ b/graphrag/index/flows/create_final_relationships.py @@ -5,20 +5,25 @@ import pandas as pd +from graphrag.index.operations.compute_degree import compute_degree from graphrag.index.operations.compute_edge_combined_degree import ( compute_edge_combined_degree, ) +from graphrag.index.operations.create_graph import create_graph def create_final_relationships( base_relationship_edges: pd.DataFrame, - base_entity_nodes: pd.DataFrame, ) -> pd.DataFrame: """All the steps to transform final relationships.""" relationships = base_relationship_edges + + graph = create_graph(base_relationship_edges) + degrees = compute_degree(graph) + relationships["combined_degree"] = compute_edge_combined_degree( relationships, - base_entity_nodes, + degrees, node_name_column="title", node_degree_column="degree", edge_source_column="source", diff --git a/graphrag/index/flows/extract_graph.py b/graphrag/index/flows/extract_graph.py index db73635b85..87e369f525 100644 --- a/graphrag/index/flows/extract_graph.py +++ b/graphrag/index/flows/extract_graph.py @@ -6,7 +6,6 @@ from typing import Any from uuid import uuid4 -import networkx as nx import pandas as pd from datashaper import ( AsyncType, @@ -14,33 +13,26 @@ ) from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.operations.create_graph import create_graph from graphrag.index.operations.extract_entities import extract_entities -from graphrag.index.operations.snapshot import snapshot -from graphrag.index.operations.snapshot_graphml import snapshot_graphml from graphrag.index.operations.summarize_descriptions import ( summarize_descriptions, ) -from graphrag.storage.pipeline_storage import PipelineStorage async def extract_graph( text_units: pd.DataFrame, callbacks: VerbCallbacks, cache: PipelineCache, - storage: PipelineStorage, extraction_strategy: dict[str, Any] | None = None, extraction_num_threads: int = 4, extraction_async_mode: AsyncType = AsyncType.AsyncIO, entity_types: list[str] | None = None, summarization_strategy: dict[str, Any] | None = None, summarization_num_threads: int = 4, - snapshot_graphml_enabled: bool = False, - snapshot_transient_enabled: bool = False, ) -> tuple[pd.DataFrame, pd.DataFrame]: """All the steps to create the base entity graph.""" # this returns a graph for each text unit, to be merged later - entity_dfs, relationship_dfs = await extract_entities( + entities, relationships = await extract_entities( text_units, callbacks, cache, @@ -52,87 +44,38 @@ async def extract_graph( num_threads=extraction_num_threads, ) - if not _validate_data(entity_dfs): + if not _validate_data(entities): error_msg = "Entity Extraction failed. No entities detected during extraction." 
callbacks.error(error_msg) raise ValueError(error_msg) - if not _validate_data(relationship_dfs): + if not _validate_data(relationships): error_msg = ( "Entity Extraction failed. No relationships detected during extraction." ) callbacks.error(error_msg) raise ValueError(error_msg) - merged_entities = _merge_entities(entity_dfs) - merged_relationships = _merge_relationships(relationship_dfs) - entity_summaries, relationship_summaries = await summarize_descriptions( - merged_entities, - merged_relationships, + entities, + relationships, callbacks, cache, strategy=summarization_strategy, num_threads=summarization_num_threads, ) - base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries) - - graph = create_graph(base_relationship_edges) - - base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph) - - if snapshot_graphml_enabled: - # todo: extract graphs at each level, and add in meta like descriptions - await snapshot_graphml( - graph, - name="graph", - storage=storage, - ) + base_relationship_edges = _prep_edges(relationships, relationship_summaries) - if snapshot_transient_enabled: - await snapshot( - base_entity_nodes, - name="base_entity_nodes", - storage=storage, - formats=["parquet"], - ) - await snapshot( - base_relationship_edges, - name="base_relationship_edges", - storage=storage, - formats=["parquet"], - ) + base_entity_nodes = _prep_nodes(entities, entity_summaries) return (base_entity_nodes, base_relationship_edges) -def _merge_entities(entity_dfs) -> pd.DataFrame: - all_entities = pd.concat(entity_dfs, ignore_index=True) - return ( - all_entities.groupby(["name", "type"], sort=False) - .agg({"description": list, "source_id": list}) - .reset_index() - ) - - -def _merge_relationships(relationship_dfs) -> pd.DataFrame: - all_relationships = pd.concat(relationship_dfs, ignore_index=False) - return ( - all_relationships.groupby(["source", "target"], sort=False) - .agg({"description": list, "source_id": list, "weight": "sum"}) - .reset_index() - ) - - -def _prep_nodes(entities, summaries, graph) -> pd.DataFrame: - degrees_df = _compute_degree(graph) +def _prep_nodes(entities, summaries) -> pd.DataFrame: entities.drop(columns=["description"], inplace=True) - nodes = ( - entities.merge(summaries, on="name", how="left") - .merge(degrees_df, on="name") - .drop_duplicates(subset="name") - .rename(columns={"name": "title", "source_id": "text_unit_ids"}) + nodes = entities.merge(summaries, on="title", how="left").drop_duplicates( + subset="title" ) nodes = nodes.loc[nodes["title"].notna()].reset_index() nodes["human_readable_id"] = nodes.index @@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame: relationships.drop(columns=["description"]) .drop_duplicates(subset=["source", "target"]) .merge(summaries, on=["source", "target"], how="left") - .rename(columns={"source_id": "text_unit_ids"}) ) edges["human_readable_id"] = edges.index edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4())) return edges -def _compute_degree(graph: nx.Graph) -> pd.DataFrame: - return pd.DataFrame([ - {"name": node, "degree": int(degree)} - for node, degree in graph.degree # type: ignore - ]) - - -def _validate_data(df_list: list[pd.DataFrame]) -> bool: - """Validate that the dataframe list is valid. 
At least one dataframe must contain data.""" - return any( - len(df) > 0 for df in df_list - ) # Check for len, not .empty, as the dfs have schemas in some cases +def _validate_data(df: pd.DataFrame) -> bool: + """Validate that the dataframe has data.""" + return len(df) > 0 diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py index f701c464f2..877966dab7 100644 --- a/graphrag/index/flows/generate_text_embeddings.py +++ b/graphrag/index/flows/generate_text_embeddings.py @@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings( strategy=text_embed_config["strategy"], ) - data = data.loc[:, ["id", "embedding"]] - if snapshot_embeddings_enabled is True: + data = data.loc[:, ["id", "embedding"]] await snapshot( data, name=f"embeddings.{name}", diff --git a/graphrag/index/graph/__init__.py b/graphrag/index/graph/__init__.py deleted file mode 100644 index cb26e59595..0000000000 --- a/graphrag/index/graph/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine graph package root.""" diff --git a/graphrag/index/graph/embedding/__init__.py b/graphrag/index/graph/embedding/__init__.py deleted file mode 100644 index ff075875a5..0000000000 --- a/graphrag/index/graph/embedding/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine graph embedding package root.""" - -from graphrag.index.graph.embedding.embedding import NodeEmbeddings, embed_nod2vec - -__all__ = ["NodeEmbeddings", "embed_nod2vec"] diff --git a/graphrag/index/graph/extractors/__init__.py b/graphrag/index/graph/extractors/__init__.py deleted file mode 100644 index 42ad16b89c..0000000000 --- a/graphrag/index/graph/extractors/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine graph extractors package root.""" - -from graphrag.index.graph.extractors.claims import ClaimExtractor -from graphrag.index.graph.extractors.community_reports import ( - CommunityReportsExtractor, -) -from graphrag.index.graph.extractors.graph import GraphExtractionResult, GraphExtractor - -__all__ = [ - "ClaimExtractor", - "CommunityReportsExtractor", - "GraphExtractionResult", - "GraphExtractor", -] diff --git a/graphrag/index/graph/extractors/claims/__init__.py b/graphrag/index/graph/extractors/claims/__init__.py deleted file mode 100644 index 897cdd1125..0000000000 --- a/graphrag/index/graph/extractors/claims/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine graph extractors claims package root.""" - -from graphrag.index.graph.extractors.claims.claim_extractor import ClaimExtractor - -__all__ = ["ClaimExtractor"] diff --git a/graphrag/index/graph/extractors/community_reports/__init__.py b/graphrag/index/graph/extractors/community_reports/__init__.py deleted file mode 100644 index 63d75de87e..0000000000 --- a/graphrag/index/graph/extractors/community_reports/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""The Indexing Engine community reports package root.""" - -import graphrag.index.graph.extractors.community_reports.schemas as schemas -from graphrag.index.graph.extractors.community_reports.build_mixed_context import ( - build_mixed_context, -) -from graphrag.index.graph.extractors.community_reports.community_reports_extractor import ( - CommunityReportsExtractor, -) -from graphrag.index.graph.extractors.community_reports.prep_community_report_context import ( - prep_community_report_context, -) -from graphrag.index.graph.extractors.community_reports.sort_context import sort_context - -__all__ = [ - "CommunityReportsExtractor", - "build_mixed_context", - "prep_community_report_context", - "schemas", - "sort_context", -] diff --git a/graphrag/index/graph/extractors/graph/__init__.py b/graphrag/index/graph/extractors/graph/__init__.py deleted file mode 100644 index c3f14bfa2f..0000000000 --- a/graphrag/index/graph/extractors/graph/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine unipartite graph package root.""" - -from graphrag.index.graph.extractors.graph.graph_extractor import ( - DEFAULT_ENTITY_TYPES, - GraphExtractionResult, - GraphExtractor, -) - -__all__ = [ - "DEFAULT_ENTITY_TYPES", - "GraphExtractionResult", - "GraphExtractor", -] diff --git a/graphrag/index/graph/extractors/summarize/__init__.py b/graphrag/index/graph/extractors/summarize/__init__.py deleted file mode 100644 index 54661d0f1c..0000000000 --- a/graphrag/index/graph/extractors/summarize/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine unipartite graph package root.""" - -from graphrag.index.graph.extractors.summarize.description_summary_extractor import ( - SummarizationResult, - SummarizeExtractor, -) - -__all__ = ["SummarizationResult", "SummarizeExtractor"] diff --git a/graphrag/index/graph/utils/__init__.py b/graphrag/index/graph/utils/__init__.py deleted file mode 100644 index 2f6971186d..0000000000 --- a/graphrag/index/graph/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine graph utils package root.""" - -from graphrag.index.graph.utils.normalize_node_names import normalize_node_names -from graphrag.index.graph.utils.stable_lcc import stable_largest_connected_component - -__all__ = ["normalize_node_names", "stable_largest_connected_component"] diff --git a/graphrag/index/graph/utils/normalize_node_names.py b/graphrag/index/graph/utils/normalize_node_names.py deleted file mode 100644 index bcc874a927..0000000000 --- a/graphrag/index/graph/utils/normalize_node_names.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing normalize_node_names method definition.""" - -import html - -import networkx as nx - - -def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph: - """Normalize node names.""" - node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore - return nx.relabel_nodes(graph, node_mapping) diff --git a/graphrag/index/graph/visualization/__init__.py b/graphrag/index/graph/visualization/__init__.py deleted file mode 100644 index 090acdec32..0000000000 --- a/graphrag/index/graph/visualization/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine graph visualization package root.""" - -from graphrag.index.graph.visualization.compute_umap_positions import ( - compute_umap_positions, - get_zero_positions, -) -from graphrag.index.graph.visualization.typing import GraphLayout, NodePosition - -__all__ = [ - "GraphLayout", - "NodePosition", - "compute_umap_positions", - "get_zero_positions", -] diff --git a/graphrag/index/graph/visualization/compute_umap_positions.py b/graphrag/index/graph/visualization/compute_umap_positions.py deleted file mode 100644 index 36ac354b72..0000000000 --- a/graphrag/index/graph/visualization/compute_umap_positions.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A module containing compute_umap_positions and visualize_embedding method definition.""" - -import matplotlib.pyplot as plt -import networkx as nx -import numpy as np - -from graphrag.index.graph.visualization.typing import NodePosition - - -def get_zero_positions( - node_labels: list[str], - node_categories: list[int] | None = None, - node_sizes: list[int] | None = None, - three_d: bool | None = False, -) -> list[NodePosition]: - """Project embedding vectors down to 2D/3D using UMAP.""" - embedding_position_data: list[NodePosition] = [] - for index, node_name in enumerate(node_labels): - node_category = 1 if node_categories is None else node_categories[index] - node_size = 1 if node_sizes is None else node_sizes[index] - - if not three_d: - embedding_position_data.append( - NodePosition( - label=str(node_name), - x=0, - y=0, - cluster=str(int(node_category)), - size=int(node_size), - ) - ) - else: - embedding_position_data.append( - NodePosition( - label=str(node_name), - x=0, - y=0, - z=0, - cluster=str(int(node_category)), - size=int(node_size), - ) - ) - return embedding_position_data - - -def compute_umap_positions( - embedding_vectors: np.ndarray, - node_labels: list[str], - node_categories: list[int] | None = None, - node_sizes: list[int] | None = None, - min_dist: float = 0.75, - n_neighbors: int = 25, - spread: int = 1, - metric: str = "euclidean", - n_components: int = 2, - random_state: int = 86, -) -> list[NodePosition]: - """Project embedding vectors down to 2D/3D using UMAP.""" - # NOTE: This import is done here to reduce the initial import time of the graphrag package - import umap - - embedding_positions = umap.UMAP( - min_dist=min_dist, - n_neighbors=n_neighbors, - spread=spread, - n_components=n_components, - metric=metric, - random_state=random_state, - ).fit_transform(embedding_vectors) - - embedding_position_data: list[NodePosition] = [] - for index, node_name in enumerate(node_labels): - node_points = embedding_positions[index] # type: ignore - node_category = 1 if node_categories is None else 
node_categories[index] - node_size = 1 if node_sizes is None else node_sizes[index] - - if len(node_points) == 2: - embedding_position_data.append( - NodePosition( - label=str(node_name), - x=float(node_points[0]), - y=float(node_points[1]), - cluster=str(int(node_category)), - size=int(node_size), - ) - ) - else: - embedding_position_data.append( - NodePosition( - label=str(node_name), - x=float(node_points[0]), - y=float(node_points[1]), - z=float(node_points[2]), - cluster=str(int(node_category)), - size=int(node_size), - ) - ) - return embedding_position_data - - -def visualize_embedding( - graph, - umap_positions: list[dict], -): - """Project embedding down to 2D using UMAP and visualize.""" - # NOTE: This import is done here to reduce the initial import time of the graphrag package - import graspologic as gc - - # rendering - plt.clf() - figure = plt.gcf() - ax = plt.gca() - - ax.set_axis_off() - figure.set_size_inches(10, 10) - figure.set_dpi(400) - - node_position_dict = { - (str)(position["label"]): (position["x"], position["y"]) - for position in umap_positions - } - node_category_dict = { - (str)(position["label"]): position["category"] for position in umap_positions - } - node_sizes = [position["size"] for position in umap_positions] - node_colors = gc.layouts.categorical_colors(node_category_dict) # type: ignore - - vertices = [] - node_color_list = [] - for node in node_position_dict: - vertices.append(node) - node_color_list.append(node_colors[node]) - - nx.draw_networkx_nodes( - graph, - pos=node_position_dict, - nodelist=vertices, - node_color=node_color_list, # type: ignore - alpha=1.0, - linewidths=0.01, - node_size=node_sizes, # type: ignore - node_shape="o", - ax=ax, - ) - plt.show() diff --git a/graphrag/index/operations/cluster_graph.py b/graphrag/index/operations/cluster_graph.py index c773587baa..07f5998f9f 100644 --- a/graphrag/index/operations/cluster_graph.py +++ b/graphrag/index/operations/cluster_graph.py @@ -9,7 +9,7 @@ import networkx as nx -from graphrag.index.graph.utils import stable_largest_connected_component +from graphrag.index.utils.stable_lcc import stable_largest_connected_component Communities = list[tuple[int, int, int, list[str]]] diff --git a/graphrag/index/operations/compute_degree.py b/graphrag/index/operations/compute_degree.py new file mode 100644 index 0000000000..b720bf6de5 --- /dev/null +++ b/graphrag/index/operations/compute_degree.py @@ -0,0 +1,15 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + + """A module containing compute_degree definition.""" + + import networkx as nx + import pandas as pd + + + def compute_degree(graph: nx.Graph) -> pd.DataFrame: + """Create a new DataFrame with the degree of each node in the graph.""" + return pd.DataFrame([ + {"title": node, "degree": int(degree)} + for node, degree in graph.degree # type: ignore + ]) diff --git a/graphrag/index/operations/embed_graph/__init__.py b/graphrag/index/operations/embed_graph/__init__.py index 07c91c3be8..3c3e4b1ca4 100644 --- a/graphrag/index/operations/embed_graph/__init__.py +++ b/graphrag/index/operations/embed_graph/__init__.py @@ -2,11 +2,3 @@ # Licensed under the MIT License """The Indexing Engine graph embed package root.""" - -from graphrag.index.operations.embed_graph.embed_graph import ( - EmbedGraphStrategyType, - embed_graph, -) -from graphrag.index.operations.embed_graph.typing import NodeEmbeddings - -__all__ = ["EmbedGraphStrategyType", "NodeEmbeddings", "embed_graph"] diff --git a/graphrag/index/operations/embed_graph/embed_graph.py b/graphrag/index/operations/embed_graph/embed_graph.py index b09098c22b..6e161db7dc 100644 --- a/graphrag/index/operations/embed_graph/embed_graph.py +++ b/graphrag/index/operations/embed_graph/embed_graph.py @@ -3,25 +3,17 @@ """A module containing embed_graph and run_embeddings methods definition.""" -from enum import Enum from typing import Any import networkx as nx -from graphrag.index.graph.embedding import embed_nod2vec -from graphrag.index.graph.utils import stable_largest_connected_component -from graphrag.index.operations.embed_graph.typing import NodeEmbeddings +from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec +from graphrag.index.operations.embed_graph.typing import ( + EmbedGraphStrategyType, + NodeEmbeddings, +) from graphrag.index.utils.load_graph import load_graph - - -class EmbedGraphStrategyType(str, Enum): - """EmbedGraphStrategyType class definition.""" - - node2vec = "node2vec" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' +from graphrag.index.utils.stable_lcc import stable_largest_connected_component def embed_graph( @@ -81,7 +73,7 @@ def run_node_2_vec(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings: graph = stable_largest_connected_component(graph) # create graph embedding using node2vec - embeddings = embed_nod2vec( + embeddings = embed_node2vec( graph=graph, dimensions=args.get("dimensions", 1536), num_walks=args.get("num_walks", 10), diff --git a/graphrag/index/graph/embedding/embedding.py b/graphrag/index/operations/embed_graph/embed_node2vec.py similarity index 98% rename from graphrag/index/graph/embedding/embedding.py rename to graphrag/index/operations/embed_graph/embed_node2vec.py index ff6f86e72a..a009c670f6 100644 --- a/graphrag/index/graph/embedding/embedding.py +++ b/graphrag/index/operations/embed_graph/embed_node2vec.py @@ -17,7 +17,7 @@ class NodeEmbeddings: embeddings: np.ndarray -def embed_nod2vec( +def embed_node2vec( graph: nx.Graph | nx.DiGraph, dimensions: int = 1536, num_walks: int = 10, diff --git a/graphrag/index/operations/embed_graph/typing.py b/graphrag/index/operations/embed_graph/typing.py index fea792c9b1..618806eaed 100644 --- a/graphrag/index/operations/embed_graph/typing.py +++ b/graphrag/index/operations/embed_graph/typing.py @@ -4,8 +4,20 @@ """A module containing different lists and dictionaries.""" # Use this for now instead of a wrapper +from enum import Enum from typing import Any + +class 
EmbedGraphStrategyType(str, Enum): + """EmbedGraphStrategyType class definition.""" + + node2vec = "node2vec" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + NodeList = list[str] EmbeddingList = list[Any] NodeEmbeddings = dict[str, list[float]] diff --git a/graphrag/index/operations/extract_covariates/__init__.py b/graphrag/index/operations/extract_covariates/__init__.py index 315f503c3e..e130668477 100644 --- a/graphrag/index/operations/extract_covariates/__init__.py +++ b/graphrag/index/operations/extract_covariates/__init__.py @@ -2,10 +2,3 @@ # Licensed under the MIT License """The Indexing Engine text extract claims package root.""" - -from graphrag.index.operations.extract_covariates.extract_covariates import ( - ExtractClaimsStrategyType, - extract_covariates, -) - -__all__ = ["ExtractClaimsStrategyType", "extract_covariates"] diff --git a/graphrag/index/graph/extractors/claims/claim_extractor.py b/graphrag/index/operations/extract_covariates/claim_extractor.py similarity index 100% rename from graphrag/index/graph/extractors/claims/claim_extractor.py rename to graphrag/index/operations/extract_covariates/claim_extractor.py diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py index ff78f949fa..5dab42b8df 100644 --- a/graphrag/index/operations/extract_covariates/extract_covariates.py +++ b/graphrag/index/operations/extract_covariates/extract_covariates.py @@ -4,6 +4,7 @@ """A module containing the extract_covariates verb definition.""" import logging +from collections.abc import Iterable from dataclasses import asdict from typing import Any @@ -14,11 +15,13 @@ derive_from_rows, ) +import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache +from graphrag.index.llm.load_llm import load_llm, read_llm_params +from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor from graphrag.index.operations.extract_covariates.typing import ( Covariate, - CovariateExtractStrategy, - ExtractClaimsStrategyType, + CovariateExtractionResult, ) log = logging.getLogger(__name__) @@ -46,14 +49,11 @@ async def extract_covariates( resolved_entities_map = {} strategy = strategy or {} - strategy_exec = load_strategy( - strategy.get("type", ExtractClaimsStrategyType.graph_intelligence) - ) strategy_config = {**strategy} async def run_strategy(row): text = row[column] - result = await strategy_exec( + result = await run_claim_extraction( text, entity_types, resolved_entities_map, callbacks, cache, strategy_config ) return [ @@ -71,20 +71,71 @@ async def run_strategy(row): return pd.DataFrame([item for row in results for item in row or []]) -def load_strategy(strategy_type: ExtractClaimsStrategyType) -> CovariateExtractStrategy: - """Load strategy method definition.""" - match strategy_type: - case ExtractClaimsStrategyType.graph_intelligence: - from graphrag.index.operations.extract_covariates.strategies import ( - run_graph_intelligence, - ) - - return run_graph_intelligence - case _: - msg = f"Unknown strategy: {strategy_type}" - raise ValueError(msg) - - def create_row_from_claim_data(row, covariate_data: Covariate, covariate_type: str): """Create a row from the claim data and the input row.""" return {**row, **asdict(covariate_data), "covariate_type": covariate_type} + + +async def run_claim_extraction( + input: str | Iterable[str], + entity_types: list[str], + resolved_entities_map: dict[str, str], + 
callbacks: VerbCallbacks, + cache: PipelineCache, + strategy_config: dict[str, Any], +) -> CovariateExtractionResult: + """Run the Claim extraction chain.""" + llm_config = read_llm_params(strategy_config.get("llm", {})) + llm = load_llm("claim_extraction", llm_config, callbacks=callbacks, cache=cache) + extraction_prompt = strategy_config.get("extraction_prompt") + max_gleanings = strategy_config.get("max_gleanings", defs.CLAIM_MAX_GLEANINGS) + tuple_delimiter = strategy_config.get("tuple_delimiter") + record_delimiter = strategy_config.get("record_delimiter") + completion_delimiter = strategy_config.get("completion_delimiter") + encoding_model = strategy_config.get("encoding_name") + + extractor = ClaimExtractor( + llm_invoker=llm, + extraction_prompt=extraction_prompt, + max_gleanings=max_gleanings, + encoding_model=encoding_model, + on_error=lambda e, s, d: ( + callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None + ), + ) + + claim_description = strategy_config.get("claim_description") + if claim_description is None: + msg = "claim_description is required for claim extraction" + raise ValueError(msg) + + input = [input] if isinstance(input, str) else input + + results = await extractor({ + "input_text": input, + "entity_specs": entity_types, + "resolved_entities": resolved_entities_map, + "claim_description": claim_description, + "tuple_delimiter": tuple_delimiter, + "record_delimiter": record_delimiter, + "completion_delimiter": completion_delimiter, + }) + + claim_data = results.output + return CovariateExtractionResult([create_covariate(item) for item in claim_data]) + + +def create_covariate(item: dict[str, Any]) -> Covariate: + """Create a covariate from the item.""" + return Covariate( + subject_id=item.get("subject_id"), + object_id=item.get("object_id"), + type=item.get("type"), + status=item.get("status"), + start_date=item.get("start_date"), + end_date=item.get("end_date"), + description=item.get("description"), + source_text=item.get("source_text"), + record_id=item.get("record_id"), + id=item.get("id"), + ) diff --git a/graphrag/index/operations/extract_covariates/strategies.py b/graphrag/index/operations/extract_covariates/strategies.py deleted file mode 100644 index edf4c3a670..0000000000 --- a/graphrag/index/operations/extract_covariates/strategies.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""A module containing run and _run_chain methods definitions.""" - -from collections.abc import Iterable -from typing import Any - -from datashaper import VerbCallbacks -from fnllm import ChatLLM - -import graphrag.config.defaults as defs -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.graph.extractors.claims import ClaimExtractor -from graphrag.index.llm.load_llm import load_llm, read_llm_params -from graphrag.index.operations.extract_covariates.typing import ( - Covariate, - CovariateExtractionResult, -) - - -async def run_graph_intelligence( - input: str | Iterable[str], - entity_types: list[str], - resolved_entities_map: dict[str, str], - callbacks: VerbCallbacks, - cache: PipelineCache, - strategy_config: dict[str, Any], -) -> CovariateExtractionResult: - """Run the Claim extraction chain.""" - llm_config = read_llm_params(strategy_config.get("llm", {})) - llm = load_llm("claim_extraction", llm_config, callbacks=callbacks, cache=cache) - return await _execute( - llm, input, entity_types, resolved_entities_map, callbacks, strategy_config - ) - - -async def _execute( - llm: ChatLLM, - texts: Iterable[str], - entity_types: list[str], - resolved_entities_map: dict[str, str], - callbacks: VerbCallbacks, - strategy_config: dict[str, Any], -) -> CovariateExtractionResult: - extraction_prompt = strategy_config.get("extraction_prompt") - max_gleanings = strategy_config.get("max_gleanings", defs.CLAIM_MAX_GLEANINGS) - tuple_delimiter = strategy_config.get("tuple_delimiter") - record_delimiter = strategy_config.get("record_delimiter") - completion_delimiter = strategy_config.get("completion_delimiter") - encoding_model = strategy_config.get("encoding_name") - - extractor = ClaimExtractor( - llm_invoker=llm, - extraction_prompt=extraction_prompt, - max_gleanings=max_gleanings, - encoding_model=encoding_model, - on_error=lambda e, s, d: ( - callbacks.error("Claim Extraction Error", e, s, d) if callbacks else None - ), - ) - - claim_description = strategy_config.get("claim_description") - if claim_description is None: - msg = "claim_description is required for claim extraction" - raise ValueError(msg) - - texts = [texts] if isinstance(texts, str) else texts - - results = await extractor({ - "input_text": texts, - "entity_specs": entity_types, - "resolved_entities": resolved_entities_map, - "claim_description": claim_description, - "tuple_delimiter": tuple_delimiter, - "record_delimiter": record_delimiter, - "completion_delimiter": completion_delimiter, - }) - - claim_data = results.output - return CovariateExtractionResult([create_covariate(item) for item in claim_data]) - - -def create_covariate(item: dict[str, Any]) -> Covariate: - """Create a covariate from the item.""" - return Covariate( - subject_id=item.get("subject_id"), - object_id=item.get("object_id"), - type=item.get("type"), - status=item.get("status"), - start_date=item.get("start_date"), - end_date=item.get("end_date"), - description=item.get("description"), - source_text=item.get("source_text"), - record_id=item.get("record_id"), - id=item.get("id"), - ) diff --git a/graphrag/index/operations/extract_covariates/typing.py b/graphrag/index/operations/extract_covariates/typing.py index 1a0a857a6f..f5c7e0a02e 100644 --- a/graphrag/index/operations/extract_covariates/typing.py +++ b/graphrag/index/operations/extract_covariates/typing.py @@ -5,7 +5,6 @@ from collections.abc import Awaitable, Callable, Iterable from dataclasses import dataclass -from enum import Enum from 
typing import Any from datashaper import VerbCallbacks @@ -49,13 +48,3 @@ class CovariateExtractionResult: ], Awaitable[CovariateExtractionResult], ] - - -class ExtractClaimsStrategyType(str, Enum): - """ExtractClaimsStrategyType class definition.""" - - graph_intelligence = "graph_intelligence" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py index 6b3e90ce39..3245c2481c 100644 --- a/graphrag/index/operations/extract_entities/extract_entities.py +++ b/graphrag/index/operations/extract_entities/extract_entities.py @@ -4,7 +4,6 @@ """A module containing entity_extract methods.""" import logging -from enum import Enum from typing import Any import pandas as pd @@ -16,26 +15,15 @@ from graphrag.cache.pipeline_cache import PipelineCache from graphrag.index.bootstrap import bootstrap -from graphrag.index.operations.extract_entities.strategies.typing import ( +from graphrag.index.operations.extract_entities.typing import ( Document, EntityExtractStrategy, + ExtractEntityStrategyType, ) log = logging.getLogger(__name__) -class ExtractEntityStrategyType(str, Enum): - """ExtractEntityStrategyType class definition.""" - - graph_intelligence = "graph_intelligence" - graph_intelligence_json = "graph_intelligence_json" - nltk = "nltk" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] @@ -49,7 +37,7 @@ async def extract_entities( async_mode: AsyncType = AsyncType.AsyncIO, entity_types=DEFAULT_ENTITY_TYPES, num_threads: int = 4, -) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]: +) -> tuple[pd.DataFrame, pd.DataFrame]: """ Extract entities from a piece of text. 
@@ -150,14 +138,17 @@ async def run_strategy(row): entity_dfs.append(pd.DataFrame(result[0])) relationship_dfs.append(pd.DataFrame(result[1])) - return (entity_dfs, relationship_dfs) + entities = _merge_entities(entity_dfs) + relationships = _merge_relationships(relationship_dfs) + + return (entities, relationships) def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy: """Load strategy method definition.""" match strategy_type: case ExtractEntityStrategyType.graph_intelligence: - from graphrag.index.operations.extract_entities.strategies.graph_intelligence import ( + from graphrag.index.operations.extract_entities.graph_intelligence_strategy import ( run_graph_intelligence, ) @@ -166,7 +157,7 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr case ExtractEntityStrategyType.nltk: bootstrap() # dynamically import nltk strategy to avoid dependency if not used - from graphrag.index.operations.extract_entities.strategies.nltk import ( + from graphrag.index.operations.extract_entities.nltk_strategy import ( run as run_nltk, ) @@ -174,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr case _: msg = f"Unknown strategy: {strategy_type}" raise ValueError(msg) + + +def _merge_entities(entity_dfs) -> pd.DataFrame: + all_entities = pd.concat(entity_dfs, ignore_index=True) + return ( + all_entities.groupby(["title", "type"], sort=False) + .agg(description=("description", list), text_unit_ids=("source_id", list)) + .reset_index() + ) + + +def _merge_relationships(relationship_dfs) -> pd.DataFrame: + all_relationships = pd.concat(relationship_dfs, ignore_index=False) + return ( + all_relationships.groupby(["source", "target"], sort=False) + .agg( + description=("description", list), + text_unit_ids=("source_id", list), + weight=("weight", "sum"), + ) + .reset_index() + ) diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/operations/extract_entities/graph_extractor.py similarity index 100% rename from graphrag/index/graph/extractors/graph/graph_extractor.py rename to graphrag/index/operations/extract_entities/graph_extractor.py diff --git a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py similarity index 95% rename from graphrag/index/operations/extract_entities/strategies/graph_intelligence.py rename to graphrag/index/operations/extract_entities/graph_intelligence_strategy.py index 308ff29d05..a91e0748d5 100644 --- a/graphrag/index/operations/extract_entities/strategies/graph_intelligence.py +++ b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py @@ -9,9 +9,9 @@ import graphrag.config.defaults as defs from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.graph.extractors import GraphExtractor from graphrag.index.llm.load_llm import load_llm, read_llm_params -from graphrag.index.operations.extract_entities.strategies.typing import ( +from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor +from graphrag.index.operations.extract_entities.typing import ( Document, EntityExtractionResult, EntityTypes, @@ -106,7 +106,7 @@ async def run_extract_entities( ) entities = [ - ({"name": item[0], **(item[1] or {})}) + ({"title": item[0], **(item[1] or {})}) for item in graph.nodes(data=True) if item is not None ] diff --git a/graphrag/index/operations/extract_entities/strategies/nltk.py 
b/graphrag/index/operations/extract_entities/nltk_strategy.py similarity index 94% rename from graphrag/index/operations/extract_entities/strategies/nltk.py rename to graphrag/index/operations/extract_entities/nltk_strategy.py index 93a910cb36..81103c6955 100644 --- a/graphrag/index/operations/extract_entities/strategies/nltk.py +++ b/graphrag/index/operations/extract_entities/nltk_strategy.py @@ -9,7 +9,7 @@ from nltk.corpus import words from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.index.operations.extract_entities.strategies.typing import ( +from graphrag.index.operations.extract_entities.typing import ( Document, EntityExtractionResult, EntityTypes, @@ -58,7 +58,7 @@ async def run( # noqa RUF029 async is required for interface return EntityExtractionResult( entities=[ - {"type": entity_type, "name": name} + {"type": entity_type, "title": name} for name, entity_type in entity_map.items() ], relationships=[], diff --git a/graphrag/index/operations/extract_entities/strategies/__init__.py b/graphrag/index/operations/extract_entities/strategies/__init__.py deleted file mode 100644 index f5cc17d750..0000000000 --- a/graphrag/index/operations/extract_entities/strategies/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""The Indexing Engine entities extraction strategies package root.""" diff --git a/graphrag/index/operations/extract_entities/strategies/typing.py b/graphrag/index/operations/extract_entities/typing.py similarity index 77% rename from graphrag/index/operations/extract_entities/strategies/typing.py rename to graphrag/index/operations/extract_entities/typing.py index aa253f7f1a..7eb2440674 100644 --- a/graphrag/index/operations/extract_entities/strategies/typing.py +++ b/graphrag/index/operations/extract_entities/typing.py @@ -5,6 +5,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass +from enum import Enum from typing import Any import networkx as nx @@ -45,3 +46,14 @@ class EntityExtractionResult: ], Awaitable[EntityExtractionResult], ] + + +class ExtractEntityStrategyType(str, Enum): + """ExtractEntityStrategyType class definition.""" + + graph_intelligence = "graph_intelligence" + nltk = "nltk" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' diff --git a/graphrag/index/operations/layout_graph/__init__.py b/graphrag/index/operations/layout_graph/__init__.py index 91a36cc087..cc80f99a97 100644 --- a/graphrag/index/operations/layout_graph/__init__.py +++ b/graphrag/index/operations/layout_graph/__init__.py @@ -2,9 +2,3 @@ # Licensed under the MIT License """The Indexing Engine graph layout package root.""" - -from graphrag.index.operations.layout_graph.layout_graph import ( - layout_graph, -) - -__all__ = ["layout_graph"] diff --git a/graphrag/index/operations/layout_graph/layout_graph.py b/graphrag/index/operations/layout_graph/layout_graph.py index 07424953e0..756fb4ff24 100644 --- a/graphrag/index/operations/layout_graph/layout_graph.py +++ b/graphrag/index/operations/layout_graph/layout_graph.py @@ -10,8 +10,8 @@ import pandas as pd from datashaper import VerbCallbacks -from graphrag.index.graph.visualization import GraphLayout -from graphrag.index.operations.embed_graph import NodeEmbeddings +from graphrag.index.operations.embed_graph.typing import NodeEmbeddings +from graphrag.index.operations.layout_graph.typing import GraphLayout class LayoutGraphStrategyType(str, Enum): @@ -81,7 +81,7 @@ def 
_run_layout( ) -> GraphLayout: match strategy: case LayoutGraphStrategyType.umap: - from graphrag.index.operations.layout_graph.methods.umap import ( + from graphrag.index.operations.layout_graph.umap import ( run as run_umap, ) @@ -92,7 +92,7 @@ def _run_layout( lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d), ) case LayoutGraphStrategyType.zero: - from graphrag.index.operations.layout_graph.methods.zero import ( + from graphrag.index.operations.layout_graph.zero import ( run as run_zero, ) diff --git a/graphrag/index/operations/layout_graph/methods/__init__.py b/graphrag/index/operations/layout_graph/methods/__init__.py deleted file mode 100644 index 5d5054122b..0000000000 --- a/graphrag/index/operations/layout_graph/methods/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Graph Layout Methods.""" diff --git a/graphrag/index/graph/visualization/typing.py b/graphrag/index/operations/layout_graph/typing.py similarity index 100% rename from graphrag/index/graph/visualization/typing.py rename to graphrag/index/operations/layout_graph/typing.py diff --git a/graphrag/index/operations/layout_graph/methods/umap.py b/graphrag/index/operations/layout_graph/umap.py similarity index 54% rename from graphrag/index/operations/layout_graph/methods/umap.py rename to graphrag/index/operations/layout_graph/umap.py index 636fd9a670..e5ab1668ca 100644 --- a/graphrag/index/operations/layout_graph/methods/umap.py +++ b/graphrag/index/operations/layout_graph/umap.py @@ -10,12 +10,11 @@ import networkx as nx import numpy as np -from graphrag.index.graph.visualization import ( +from graphrag.index.operations.embed_graph.typing import NodeEmbeddings +from graphrag.index.operations.layout_graph.typing import ( GraphLayout, NodePosition, - compute_umap_positions, ) -from graphrag.index.operations.embed_graph import NodeEmbeddings from graphrag.index.typing import ErrorHandlerFn # TODO: This could be handled more elegantly, like what columns to use @@ -80,3 +79,58 @@ def _filter_raw_embeddings(embeddings: NodeEmbeddings) -> NodeEmbeddings: for node_id, embedding in embeddings.items() if embedding is not None } + + +def compute_umap_positions( + embedding_vectors: np.ndarray, + node_labels: list[str], + node_categories: list[int] | None = None, + node_sizes: list[int] | None = None, + min_dist: float = 0.75, + n_neighbors: int = 25, + spread: int = 1, + metric: str = "euclidean", + n_components: int = 2, + random_state: int = 86, +) -> list[NodePosition]: + """Project embedding vectors down to 2D/3D using UMAP.""" + # NOTE: This import is done here to reduce the initial import time of the graphrag package + import umap + + embedding_positions = umap.UMAP( + min_dist=min_dist, + n_neighbors=n_neighbors, + spread=spread, + n_components=n_components, + metric=metric, + random_state=random_state, + ).fit_transform(embedding_vectors) + + embedding_position_data: list[NodePosition] = [] + for index, node_name in enumerate(node_labels): + node_points = embedding_positions[index] # type: ignore + node_category = 1 if node_categories is None else node_categories[index] + node_size = 1 if node_sizes is None else node_sizes[index] + + if len(node_points) == 2: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=float(node_points[0]), + y=float(node_points[1]), + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + else: + embedding_position_data.append( + NodePosition( + label=str(node_name), + 
x=float(node_points[0]), + y=float(node_points[1]), + z=float(node_points[2]), + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + return embedding_position_data diff --git a/graphrag/index/operations/layout_graph/methods/zero.py b/graphrag/index/operations/layout_graph/zero.py similarity index 59% rename from graphrag/index/operations/layout_graph/methods/zero.py rename to graphrag/index/operations/layout_graph/zero.py index f41d2d4ca4..4bb7d39b00 100644 --- a/graphrag/index/operations/layout_graph/methods/zero.py +++ b/graphrag/index/operations/layout_graph/zero.py @@ -9,10 +9,9 @@ import networkx as nx -from graphrag.index.graph.visualization import ( +from graphrag.index.operations.layout_graph.typing import ( GraphLayout, NodePosition, - get_zero_positions, ) from graphrag.index.typing import ErrorHandlerFn @@ -61,3 +60,39 @@ def run( NodePosition(x=0, y=0, label=nodes[i], size=0, cluster=str(cluster)) ) return result + + +def get_zero_positions( + node_labels: list[str], + node_categories: list[int] | None = None, + node_sizes: list[int] | None = None, + three_d: bool | None = False, +) -> list[NodePosition]: + """Assign all nodes fixed zero positions in 2D/3D.""" + embedding_position_data: list[NodePosition] = [] + for index, node_name in enumerate(node_labels): + node_category = 1 if node_categories is None else node_categories[index] + node_size = 1 if node_sizes is None else node_sizes[index] + + if not three_d: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=0, + y=0, + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + else: + embedding_position_data.append( + NodePosition( + label=str(node_name), + x=0, + y=0, + z=0, + cluster=str(int(node_category)), + size=int(node_size), + ) + ) + return embedding_position_data diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/__init__.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/__init__.py new file mode 100644 index 0000000000..5daf0df3cd --- /dev/null +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""The Indexing Engine community reports package root.""" + +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import ( + build_mixed_context, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import ( + CommunityReportsExtractor, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.prep_community_report_context import ( + prep_community_report_context, +) +from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( + sort_context, +) + +__all__ = [ + "CommunityReportsExtractor", + "build_mixed_context", + "prep_community_report_context", + "schemas", + "sort_context", +] diff --git a/graphrag/index/graph/extractors/community_reports/build_mixed_context.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/build_mixed_context.py similarity index 92% rename from graphrag/index/graph/extractors/community_reports/build_mixed_context.py rename to graphrag/index/operations/summarize_communities/community_reports_extractor/build_mixed_context.py index ca10ca948d..32f19f772e 100644 --- a/graphrag/index/graph/extractors/community_reports/build_mixed_context.py +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/build_mixed_context.py @@ -4,8 +4,10 @@ import pandas as pd -import graphrag.index.graph.extractors.community_reports.schemas as schemas -from graphrag.index.graph.extractors.community_reports.sort_context import sort_context +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import ( + sort_context, +) from graphrag.query.llm.text_utils import num_tokens diff --git a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py similarity index 100% rename from graphrag/index/graph/extractors/community_reports/community_reports_extractor.py rename to graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py diff --git a/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py similarity index 95% rename from graphrag/index/graph/extractors/community_reports/prep_community_report_context.py rename to graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py index d2bf3a0f5d..bb5125f12f 100644 --- a/graphrag/index/graph/extractors/community_reports/prep_community_report_context.py +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/prep_community_report_context.py @@ -8,11 +8,13 @@ import pandas as pd -import graphrag.index.graph.extractors.community_reports.schemas as schemas -from graphrag.index.graph.extractors.community_reports.build_mixed_context import ( +import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas +from graphrag.index.operations.summarize_communities.community_reports_extractor.build_mixed_context import ( 
     build_mixed_context,
 )
-from graphrag.index.graph.extractors.community_reports.sort_context import sort_context
+from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
+    sort_context,
+)
 from graphrag.index.utils.dataframes import (
     antijoin,
     drop_columns,
diff --git a/graphrag/index/graph/extractors/community_reports/schemas.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/schemas.py
similarity index 100%
rename from graphrag/index/graph/extractors/community_reports/schemas.py
rename to graphrag/index/operations/summarize_communities/community_reports_extractor/schemas.py
diff --git a/graphrag/index/graph/extractors/community_reports/sort_context.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/sort_context.py
similarity index 98%
rename from graphrag/index/graph/extractors/community_reports/sort_context.py
rename to graphrag/index/operations/summarize_communities/community_reports_extractor/sort_context.py
index ab56083438..cd17578b02 100644
--- a/graphrag/index/graph/extractors/community_reports/sort_context.py
+++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/sort_context.py
@@ -4,7 +4,7 @@
 import pandas as pd
 
-import graphrag.index.graph.extractors.community_reports.schemas as schemas
+import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
 from graphrag.query.llm.text_utils import num_tokens
 
diff --git a/graphrag/index/graph/extractors/community_reports/utils.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/utils.py
similarity index 81%
rename from graphrag/index/graph/extractors/community_reports/utils.py
rename to graphrag/index/operations/summarize_communities/community_reports_extractor/utils.py
index 48afd671b9..c847451b84 100644
--- a/graphrag/index/graph/extractors/community_reports/utils.py
+++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/utils.py
@@ -5,7 +5,7 @@
 import pandas as pd
 
-import graphrag.index.graph.extractors.community_reports.schemas as schemas
+import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
 
 
 def get_levels(df: pd.DataFrame, level_column: str = schemas.NODE_LEVEL) -> list[int]:
diff --git a/graphrag/index/operations/summarize_communities/prepare_community_reports.py b/graphrag/index/operations/summarize_communities/prepare_community_reports.py
index c2156edd69..45a6fec6d8 100644
--- a/graphrag/index/operations/summarize_communities/prepare_community_reports.py
+++ b/graphrag/index/operations/summarize_communities/prepare_community_reports.py
@@ -11,11 +11,13 @@
     progress_iterable,
 )
 
-import graphrag.index.graph.extractors.community_reports.schemas as schemas
-from graphrag.index.graph.extractors.community_reports.sort_context import (
+import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
+from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
     parallel_sort_context_batch,
 )
-from graphrag.index.graph.extractors.community_reports.utils import get_levels
+from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
+    get_levels,
+)
 
 log = logging.getLogger(__name__)
diff --git a/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py b/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py
index 2512db484f..cab9058074 100644
--- a/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py
+++ b/graphrag/index/operations/summarize_communities/restore_community_hierarchy.py
@@ -8,7 +8,7 @@
 import pandas as pd
 
-import graphrag.index.graph.extractors.community_reports.schemas as schemas
+import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
 
 log = logging.getLogger(__name__)
 
diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py
index 8142572b6d..43adb07148 100644
--- a/graphrag/index/operations/summarize_communities/strategies.py
+++ b/graphrag/index/operations/summarize_communities/strategies.py
@@ -10,10 +10,10 @@
 from fnllm import ChatLLM
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.graph.extractors.community_reports import (
+from graphrag.index.llm.load_llm import load_llm, read_llm_params
+from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
     CommunityReportsExtractor,
 )
-from graphrag.index.llm.load_llm import load_llm, read_llm_params
 from graphrag.index.operations.summarize_communities.typing import (
     CommunityReport,
     Finding,
diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py
index 35a0192b82..d4c5c01072 100644
--- a/graphrag/index/operations/summarize_communities/summarize_communities.py
+++ b/graphrag/index/operations/summarize_communities/summarize_communities.py
@@ -15,12 +15,14 @@
 )
 
 import graphrag.config.defaults as defaults
-import graphrag.index.graph.extractors.community_reports.schemas as schemas
+import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.graph.extractors.community_reports import (
+from graphrag.index.operations.summarize_communities.community_reports_extractor import (
     prep_community_report_context,
 )
-from graphrag.index.graph.extractors.community_reports.utils import get_levels
+from graphrag.index.operations.summarize_communities.community_reports_extractor.utils import (
+    get_levels,
+)
 from graphrag.index.operations.summarize_communities.typing import (
     CommunityReport,
     CommunityReportsStrategy,
diff --git a/graphrag/index/graph/extractors/summarize/description_summary_extractor.py b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py
similarity index 100%
rename from graphrag/index/graph/extractors/summarize/description_summary_extractor.py
rename to graphrag/index/operations/summarize_descriptions/description_summary_extractor.py
diff --git a/graphrag/index/operations/summarize_descriptions/strategies.py b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py
similarity index 94%
rename from graphrag/index/operations/summarize_descriptions/strategies.py
rename to graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py
index 0538f0e225..4a22b9b554 100644
--- a/graphrag/index/operations/summarize_descriptions/strategies.py
+++ b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py
@@ -7,8 +7,10 @@
 from fnllm import ChatLLM
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.graph.extractors.summarize import SummarizeExtractor
 from graphrag.index.llm.load_llm import load_llm, read_llm_params
+from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
+    SummarizeExtractor,
+)
 from graphrag.index.operations.summarize_descriptions.typing import (
     StrategyConfig,
     SummarizedDescriptionResult,
diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py
index a8941b0c90..cf6650dd08 100644
--- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py
+++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py
@@ -88,7 +88,7 @@ async def get_summarized(
 
     node_futures = [
         do_summarize_descriptions(
-            str(row[1]["name"]),
+            str(row[1]["title"]),
             sorted(set(row[1]["description"])),
             ticker,
             semaphore,
@@ -100,7 +100,7 @@ async def get_summarized(
 
     node_descriptions = [
         {
-            "name": result.id,
+            "title": result.id,
             "description": result.description,
         }
         for result in node_results
@@ -157,7 +157,7 @@ def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy
     """Load strategy method definition."""
     match strategy_type:
         case SummarizeStrategyType.graph_intelligence:
-            from graphrag.index.operations.summarize_descriptions.strategies import (
+            from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
                 run_graph_intelligence,
             )
diff --git a/graphrag/index/update/entities.py b/graphrag/index/update/entities.py
index 12792076f3..3c117ac837 100644
--- a/graphrag/index/update/entities.py
+++ b/graphrag/index/update/entities.py
@@ -12,7 +12,7 @@
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.index.config.pipeline import PipelineConfig
-from graphrag.index.operations.summarize_descriptions.strategies import (
+from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
     run_graph_intelligence as run_entity_summarization,
 )
 from graphrag.index.run.workflow import _find_workflow_config
 
diff --git a/graphrag/index/graph/utils/stable_lcc.py b/graphrag/index/utils/stable_lcc.py
similarity index 89%
rename from graphrag/index/graph/utils/stable_lcc.py
rename to graphrag/index/utils/stable_lcc.py
index e5b0bf60b7..070311331b 100644
--- a/graphrag/index/graph/utils/stable_lcc.py
+++ b/graphrag/index/utils/stable_lcc.py
@@ -3,12 +3,11 @@
 """A module for producing a stable largest connected component, i.e.
 same input graph == same output lcc."""
 
+import html
 from typing import Any, cast
 
 import networkx as nx
 
-from graphrag.index.graph.utils.normalize_node_names import normalize_node_names
-
 
 def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
     """Return the largest connected component of the graph, with nodes and edges sorted in a stable way."""
@@ -60,3 +59,9 @@ def _get_edge_key(source: Any, target: Any) -> str:
     fixed_graph.add_edges_from(edges)
 
     return fixed_graph
+
+
+def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
+    """Normalize node names."""
+    node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()}  # type: ignore
+    return nx.relabel_nodes(graph, node_mapping)
diff --git a/graphrag/index/workflows/v1/compute_communities.py b/graphrag/index/workflows/v1/compute_communities.py
index 46a739d392..543a6ece6a 100644
--- a/graphrag/index/workflows/v1/compute_communities.py
+++ b/graphrag/index/workflows/v1/compute_communities.py
@@ -14,6 +14,7 @@
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.compute_communities import compute_communities
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "compute_communities"
 
@@ -62,13 +63,19 @@
     """All the steps to create the base entity graph."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
 
-    base_communities = await compute_communities(
+    base_communities = compute_communities(
         base_relationship_edges,
-        storage,
         clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_communities", base_communities)
 
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_communities,
+            name="base_communities",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py
index 37e849b8e2..d045f0d6e4 100644
--- a/graphrag/index/workflows/v1/create_base_text_units.py
+++ b/graphrag/index/workflows/v1/create_base_text_units.py
@@ -19,6 +19,7 @@
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units,
 )
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_base_text_units"
@@ -65,17 +66,23 @@
     """All the steps to transform base text_units."""
     source = cast("pd.DataFrame", input.get_input())
 
-    output = await create_base_text_units(
+    output = create_base_text_units(
         source,
         callbacks,
-        storage,
         chunk_by_columns,
         chunk_strategy=chunk_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_text_units", output)
 
+    if snapshot_transient_enabled:
+        await snapshot(
+            output,
+            name="create_base_text_units",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(
         cast(
             "Table",
diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py
index a278607e1e..603b03f75d 100644
--- a/graphrag/index/workflows/v1/create_final_relationships.py
+++ b/graphrag/index/workflows/v1/create_final_relationships.py
@@ -53,8 +53,7 @@ async def workflow(
 ) -> VerbResult:
     """All the steps to transform final relationships."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
 
-    output = create_final_relationships(base_relationship_edges, base_entity_nodes)
+    output = create_final_relationships(base_relationship_edges)
 
     return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/extract_graph.py b/graphrag/index/workflows/v1/extract_graph.py
index 86af232cfe..65016d6a6a 100644
--- a/graphrag/index/workflows/v1/extract_graph.py
+++ b/graphrag/index/workflows/v1/extract_graph.py
@@ -19,6 +19,9 @@
 from graphrag.index.flows.extract_graph import (
     extract_graph,
 )
+from graphrag.index.operations.create_graph import create_graph
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "extract_graph"
@@ -90,18 +93,38 @@
         text_units,
         callbacks,
         cache,
-        storage,
         extraction_strategy=extraction_strategy,
         extraction_num_threads=extraction_num_threads,
         extraction_async_mode=extraction_async_mode,
         entity_types=entity_types,
         summarization_strategy=summarization_strategy,
         summarization_num_threads=summarization_num_threads,
-        snapshot_graphml_enabled=snapshot_graphml_enabled,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_entity_nodes", base_entity_nodes)
     await runtime_storage.set("base_relationship_edges", base_relationship_edges)
 
+    if snapshot_graphml_enabled:
+        # todo: extract graphs at each level, and add in meta like descriptions
+        graph = create_graph(base_relationship_edges)
+        await snapshot_graphml(
+            graph,
+            name="graph",
+            storage=storage,
+        )
+
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_entity_nodes,
+            name="base_entity_nodes",
+            storage=storage,
+            formats=["parquet"],
+        )
+        await snapshot(
+            base_relationship_edges,
+            name="base_relationship_edges",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index 3fb865dce3..86b4c49a11 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -58,8 +58,8 @@
             ],
             "nan_allowed_columns": [
                 "description",
-                "community",
-                "level"
+                "x",
+                "y"
             ],
             "subworkflows": 1,
             "max_runtime": 150,
diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json
index 812603dd83..cf4180857c 100644
--- a/tests/fixtures/text/config.json
+++ b/tests/fixtures/text/config.json
@@ -76,8 +76,8 @@
             ],
             "nan_allowed_columns": [
                 "description",
-                "community",
-                "level"
+                "x",
+                "y"
             ],
             "subworkflows": 1,
             "max_runtime": 150,
diff --git a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py
index 3610c55983..4f3a84f586 100644
--- a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py
+++ b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py
@@ -3,7 +3,9 @@
 import math
 import platform
 
-from graphrag.index.graph.extractors.community_reports import sort_context
+from graphrag.index.operations.summarize_communities.community_reports_extractor import (
+    sort_context,
+)
 from graphrag.query.llm.text_utils import num_tokens
 
 nan = math.nan
diff --git a/tests/unit/indexing/graph/utils/test_stable_lcc.py b/tests/unit/indexing/graph/utils/test_stable_lcc.py
index 02ddc2989a..244f3b905d 100644
--- a/tests/unit/indexing/graph/utils/test_stable_lcc.py
+++ b/tests/unit/indexing/graph/utils/test_stable_lcc.py
@@ -4,7 +4,7 @@
 
 import networkx as nx
 
-from graphrag.index.graph.utils.stable_lcc import stable_largest_connected_component
+from graphrag.index.utils.stable_lcc import stable_largest_connected_component
 
 
 class TestStableLCC(unittest.TestCase):
diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
index d72f10b344..0403bee4d1 100644
--- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
+++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
@@ -2,10 +2,10 @@
 # Licensed under the MIT License
 import unittest
 
-from graphrag.index.operations.extract_entities.strategies.graph_intelligence import (
+from graphrag.index.operations.extract_entities.graph_intelligence_strategy import (
     run_extract_entities,
 )
-from graphrag.index.operations.extract_entities.strategies.typing import (
+from graphrag.index.operations.extract_entities.typing import (
     Document,
 )
 from tests.unit.indexing.verbs.helpers.mock_llm import create_mock_llm
@@ -42,7 +42,7 @@ async def test_run_extract_entities_single_document_correct_entities_returned(se
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
 
     async def test_run_extract_entities_multiple_documents_correct_entities_returned(
@@ -81,7 +81,7 @@ async def test_run_extract_entities_multiple_documents_correct_entities_returned
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
 
     async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):
diff --git a/tests/verbs/data/base_communities.parquet b/tests/verbs/data/base_communities.parquet
index d9c32028ae..93edaa4504 100644
Binary files a/tests/verbs/data/base_communities.parquet and b/tests/verbs/data/base_communities.parquet differ
diff --git a/tests/verbs/data/base_entity_nodes.parquet b/tests/verbs/data/base_entity_nodes.parquet
index 0240e7835b..2b13b1db5c 100644
Binary files a/tests/verbs/data/base_entity_nodes.parquet and b/tests/verbs/data/base_entity_nodes.parquet differ
diff --git a/tests/verbs/data/base_relationship_edges.parquet b/tests/verbs/data/base_relationship_edges.parquet
index 4f1e302550..4f2b4d837d 100644
Binary files a/tests/verbs/data/base_relationship_edges.parquet and b/tests/verbs/data/base_relationship_edges.parquet differ
diff --git a/tests/verbs/data/create_base_text_units.parquet b/tests/verbs/data/create_base_text_units.parquet
index 4e1ae106f4..2f8c7f9037 100644
Binary files a/tests/verbs/data/create_base_text_units.parquet and b/tests/verbs/data/create_base_text_units.parquet differ
diff --git a/tests/verbs/data/create_final_communities.parquet b/tests/verbs/data/create_final_communities.parquet
index 6c170751c8..dbb0dd8b22 100644
Binary files a/tests/verbs/data/create_final_communities.parquet and b/tests/verbs/data/create_final_communities.parquet differ
diff --git a/tests/verbs/data/create_final_community_reports.parquet b/tests/verbs/data/create_final_community_reports.parquet
index 5b29215f50..a04e161573 100644
Binary files a/tests/verbs/data/create_final_community_reports.parquet and b/tests/verbs/data/create_final_community_reports.parquet differ
diff --git a/tests/verbs/data/create_final_covariates.parquet b/tests/verbs/data/create_final_covariates.parquet
index 64b8928b4b..798098cd15 100644
Binary files a/tests/verbs/data/create_final_covariates.parquet and b/tests/verbs/data/create_final_covariates.parquet differ
diff --git a/tests/verbs/data/create_final_documents.parquet b/tests/verbs/data/create_final_documents.parquet
index e90c863e74..33cb2ae3b8 100644
Binary files a/tests/verbs/data/create_final_documents.parquet and b/tests/verbs/data/create_final_documents.parquet differ
diff --git a/tests/verbs/data/create_final_entities.parquet b/tests/verbs/data/create_final_entities.parquet
index 0938f79236..4c7cd2db4f 100644
Binary files a/tests/verbs/data/create_final_entities.parquet and b/tests/verbs/data/create_final_entities.parquet differ
diff --git a/tests/verbs/data/create_final_nodes.parquet b/tests/verbs/data/create_final_nodes.parquet
index 5ac4fea26c..5a3afdaa6f 100644
Binary files a/tests/verbs/data/create_final_nodes.parquet and b/tests/verbs/data/create_final_nodes.parquet differ
diff --git a/tests/verbs/data/create_final_relationships.parquet b/tests/verbs/data/create_final_relationships.parquet
index 9b8b16c8c0..911b3c2b4a 100644
Binary files a/tests/verbs/data/create_final_relationships.parquet and b/tests/verbs/data/create_final_relationships.parquet differ
diff --git a/tests/verbs/data/create_final_text_units.parquet b/tests/verbs/data/create_final_text_units.parquet
index 8853779299..f5766acb25 100644
Binary files a/tests/verbs/data/create_final_text_units.parquet and b/tests/verbs/data/create_final_text_units.parquet differ
diff --git a/tests/verbs/test_compute_communities.py b/tests/verbs/test_compute_communities.py
index 07db7d42c3..5c91fc46b7 100644
--- a/tests/verbs/test_compute_communities.py
+++ b/tests/verbs/test_compute_communities.py
@@ -4,7 +4,6 @@
 from graphrag.index.flows.compute_communities import (
     compute_communities,
 )
-from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.compute_communities import (
     workflow_name,
 )
@@ -16,37 +15,15 @@
 )
 
 
-async def test_compute_communities():
+def test_compute_communities():
     edges = load_test_table("base_relationship_edges")
     expected = load_test_table("base_communities")
 
-    context = create_run_context(None, None, None)
     config = get_config_for_workflow(workflow_name)
     clustering_strategy = config["cluster_graph"]["strategy"]
 
-    actual = await compute_communities(
-        edges, storage=context.storage, clustering_strategy=clustering_strategy
-    )
+    actual = compute_communities(edges, clustering_strategy=clustering_strategy)
 
     columns = list(expected.columns.values)
     compare_outputs(actual, expected, columns)
     assert len(actual.columns) == len(expected.columns)
-
-
-async def test_compute_communities_with_snapshots():
-    edges = load_test_table("base_relationship_edges")
-
-    context = create_run_context(None, None, None)
-    config = get_config_for_workflow(workflow_name)
-    clustering_strategy = config["cluster_graph"]["strategy"]
-
-    await compute_communities(
-        edges,
-        storage=context.storage,
-        clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=True,
-    )
-
-    assert context.storage.keys() == [
-        "base_communities.parquet",
-    ], "Community snapshot keys differ"
diff --git a/tests/verbs/test_create_final_community_reports.py b/tests/verbs/test_create_final_community_reports.py
index 1a04e9c56f..85a6c3ee2b 100644
--- a/tests/verbs/test_create_final_community_reports.py
+++ b/tests/verbs/test_create_final_community_reports.py
@@ -6,7 +6,7 @@
 from datashaper.errors import VerbParallelizationError
 
 from graphrag.config.enums import LLMType
-from graphrag.index.graph.extractors.community_reports.community_reports_extractor import (
+from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
     CommunityReportResponse,
     FindingModel,
 )
diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py
index 3d3a4faf3b..9f01e08304 100644
--- a/tests/verbs/test_create_final_relationships.py
+++ b/tests/verbs/test_create_final_relationships.py
@@ -16,10 +16,9 @@
 
 def test_create_final_relationships():
     edges = load_test_table("base_relationship_edges")
-    nodes = load_test_table("base_entity_nodes")
     expected = load_test_table(workflow_name)
 
-    actual = create_final_relationships(edges, nodes)
+    actual = create_final_relationships(edges)
 
     assert "id" in expected.columns
     columns = list(expected.columns.values)