diff --git a/.semversioner/next-release/patch-20241212190223784600.json b/.semversioner/next-release/patch-20241212190223784600.json
new file mode 100644
index 000000000..d54d621cf
--- /dev/null
+++ b/.semversioner/next-release/patch-20241212190223784600.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Streamline flows."
+}
diff --git a/graphrag/index/flows/compute_communities.py b/graphrag/index/flows/compute_communities.py
index 09ec084ac..6ca74ded4 100644
--- a/graphrag/index/flows/compute_communities.py
+++ b/graphrag/index/flows/compute_communities.py
@@ -9,15 +9,11 @@
 
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
-
     return base_communities
diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py
index f699d507b..3204425d1 100644
--- a/graphrag/index/flows/create_base_text_units.py
+++ b/graphrag/index/flows/create_base_text_units.py
@@ -15,18 +15,14 @@
 )
 
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)
 
-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
 
 
 # TODO: would be nice to inline this completely in the main method with pandas
diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py
index a468e830a..0b6932e40 100644
--- a/graphrag/index/flows/create_final_nodes.py
+++ b/graphrag/index/flows/create_final_nodes.py
@@ -10,6 +10,7 @@
     VerbCallbacks,
 )
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.embed_graph.embed_graph import embed_graph
 from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
-    )
 
-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
+    degrees = compute_degree(graph)
 
-    return joined.loc[
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
+    )
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
             "id",
diff --git a/graphrag/index/flows/create_final_relationships.py b/graphrag/index/flows/create_final_relationships.py
index 5f58d1ffb..03f6e362e 100644
--- a/graphrag/index/flows/create_final_relationships.py
+++ b/graphrag/index/flows/create_final_relationships.py
@@ -5,20 +5,25 @@
 
 import pandas as pd
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph
 
 
 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
+
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
+
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",
diff --git a/graphrag/index/flows/extract_graph.py b/graphrag/index/flows/extract_graph.py
index db73635b8..87e369f52 100644
--- a/graphrag/index/flows/extract_graph.py
+++ b/graphrag/index/flows/extract_graph.py
@@ -6,7 +6,6 @@
 from typing import Any
 from uuid import uuid4
 
-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
@@ -14,33 +13,26 @@
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )
 
-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
-
-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)
 
     return (base_entity_nodes, base_relationship_edges)
 
 
-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
        .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges
 
 
-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
diff --git a/graphrag/index/flows/generate_text_embeddings.py b/graphrag/index/flows/generate_text_embeddings.py
index f701c464f..877966dab 100644
--- a/graphrag/index/flows/generate_text_embeddings.py
+++ b/graphrag/index/flows/generate_text_embeddings.py
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
         strategy=text_embed_config["strategy"],
     )
 
-    data = data.loc[:, ["id", "embedding"]]
-
     if snapshot_embeddings_enabled is True:
+        data = data.loc[:, ["id", "embedding"]]
         await snapshot(
             data,
             name=f"embeddings.{name}",
diff --git a/graphrag/index/operations/compute_degree.py b/graphrag/index/operations/compute_degree.py
new file mode 100644
index 000000000..b720bf6de
--- /dev/null
+++ b/graphrag/index/operations/compute_degree.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing compute_degree definition."""
+
+import networkx as nx
+import pandas as pd
+
+
+def compute_degree(graph: nx.Graph) -> pd.DataFrame:
+    """Create a new DataFrame with the degree of each node in the graph."""
+    return pd.DataFrame([
+        {"title": node, "degree": int(degree)}
+        for node, degree in graph.degree  # type: ignore
+    ])
diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py
index 89598822f..3245c2481 100644
--- a/graphrag/index/operations/extract_entities/extract_entities.py
+++ b/graphrag/index/operations/extract_entities/extract_entities.py
@@ -37,7 +37,7 @@ async def extract_entities(
     async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types=DEFAULT_ENTITY_TYPES,
     num_threads: int = 4,
-) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Extract entities from a piece of text.
 
@@ -138,7 +138,10 @@ async def run_strategy(row):
             entity_dfs.append(pd.DataFrame(result[0]))
             relationship_dfs.append(pd.DataFrame(result[1]))
 
-    return (entity_dfs, relationship_dfs)
+    entities = _merge_entities(entity_dfs)
+    relationships = _merge_relationships(relationship_dfs)
+
+    return (entities, relationships)
 
 
 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStr
         case _:
             msg = f"Unknown strategy: {strategy_type}"
             raise ValueError(msg)
+
+
+def _merge_entities(entity_dfs) -> pd.DataFrame:
+    all_entities = pd.concat(entity_dfs, ignore_index=True)
+    return (
+        all_entities.groupby(["title", "type"], sort=False)
+        .agg(description=("description", list), text_unit_ids=("source_id", list))
+        .reset_index()
+    )
+
+
+def _merge_relationships(relationship_dfs) -> pd.DataFrame:
+    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
+    return (
+        all_relationships.groupby(["source", "target"], sort=False)
+        .agg(
+            description=("description", list),
+            text_unit_ids=("source_id", list),
+            weight=("weight", "sum"),
+        )
+        .reset_index()
+    )
diff --git a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py
index fa30d6c0c..a91e0748d 100644
--- a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py
+++ b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py
@@ -106,7 +106,7 @@ async def run_extract_entities(
     )
 
     entities = [
-        ({"name": item[0], **(item[1] or {})})
+        ({"title": item[0], **(item[1] or {})})
         for item in graph.nodes(data=True)
         if item is not None
     ]
diff --git a/graphrag/index/operations/extract_entities/nltk_strategy.py b/graphrag/index/operations/extract_entities/nltk_strategy.py
index 13bc6742f..81103c695 100644
--- a/graphrag/index/operations/extract_entities/nltk_strategy.py
+++ b/graphrag/index/operations/extract_entities/nltk_strategy.py
@@ -58,7 +58,7 @@ async def run(  # noqa RUF029 async is required for interface
 
     return EntityExtractionResult(
         entities=[
-            {"type": entity_type, "name": name}
+            {"type": entity_type, "title": name}
             for name, entity_type in entity_map.items()
         ],
         relationships=[],
diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py
index bfe93f441..cf6650dd0 100644
--- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py
+++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py
@@ -88,7 +88,7 @@ async def get_summarized(
 
         node_futures = [
             do_summarize_descriptions(
-                str(row[1]["name"]),
+                str(row[1]["title"]),
                 sorted(set(row[1]["description"])),
                 ticker,
                 semaphore,
@@ -100,7 +100,7 @@ async def get_summarized(
 
         node_descriptions = [
             {
-                "name": result.id,
+                "title": result.id,
                 "description": result.description,
             }
             for result in node_results
diff --git a/graphrag/index/workflows/v1/compute_communities.py b/graphrag/index/workflows/v1/compute_communities.py
index 46a739d39..543a6ece6 100644
--- a/graphrag/index/workflows/v1/compute_communities.py
+++ b/graphrag/index/workflows/v1/compute_communities.py
@@ -14,6 +14,7 @@
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
 from graphrag.index.flows.compute_communities import compute_communities
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "compute_communities"
 
@@ -62,13 +63,19 @@ async def workflow(
     """All the steps to create the base entity graph."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
 
-    base_communities = await compute_communities(
+    base_communities = compute_communities(
         base_relationship_edges,
-        storage,
         clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_communities", base_communities)
 
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_communities,
+            name="base_communities",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py
index 37e849b8e..d045f0d6e 100644
--- a/graphrag/index/workflows/v1/create_base_text_units.py
+++ b/graphrag/index/workflows/v1/create_base_text_units.py
@@ -19,6 +19,7 @@
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units,
 )
+from graphrag.index.operations.snapshot import snapshot
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_base_text_units"
@@ -65,17 +66,23 @@ async def workflow(
     """All the steps to transform base text_units."""
     source = cast("pd.DataFrame", input.get_input())
 
-    output = await create_base_text_units(
+    output = create_base_text_units(
         source,
         callbacks,
-        storage,
         chunk_by_columns,
         chunk_strategy=chunk_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_text_units", output)
 
+    if snapshot_transient_enabled:
+        await snapshot(
+            output,
+            name="create_base_text_units",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(
         cast(
             "Table",
diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py
index a278607e1..603b03f75 100644
--- a/graphrag/index/workflows/v1/create_final_relationships.py
+++ b/graphrag/index/workflows/v1/create_final_relationships.py
@@ -53,8 +53,7 @@ async def workflow(
 ) -> VerbResult:
     """All the steps to transform final relationships."""
     base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
 
-    output = create_final_relationships(base_relationship_edges, base_entity_nodes)
+    output = create_final_relationships(base_relationship_edges)
 
     return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/extract_graph.py b/graphrag/index/workflows/v1/extract_graph.py
index 86af232cf..65016d6a6 100644
--- a/graphrag/index/workflows/v1/extract_graph.py
+++ b/graphrag/index/workflows/v1/extract_graph.py
@@ -19,6 +19,9 @@
 from graphrag.index.flows.extract_graph import (
     extract_graph,
 )
+from graphrag.index.operations.create_graph import create_graph
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "extract_graph"
@@ -90,18 +93,38 @@ async def workflow(
         text_units,
         callbacks,
         cache,
-        storage,
         extraction_strategy=extraction_strategy,
         extraction_num_threads=extraction_num_threads,
         extraction_async_mode=extraction_async_mode,
         entity_types=entity_types,
         summarization_strategy=summarization_strategy,
         summarization_num_threads=summarization_num_threads,
-        snapshot_graphml_enabled=snapshot_graphml_enabled,
-        snapshot_transient_enabled=snapshot_transient_enabled,
     )
 
     await runtime_storage.set("base_entity_nodes", base_entity_nodes)
     await runtime_storage.set("base_relationship_edges", base_relationship_edges)
 
+    if snapshot_graphml_enabled:
+        # todo: extract graphs at each level, and add in meta like descriptions
+        graph = create_graph(base_relationship_edges)
+        await snapshot_graphml(
+            graph,
+            name="graph",
+            storage=storage,
+        )
+
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_entity_nodes,
+            name="base_entity_nodes",
+            storage=storage,
+            formats=["parquet"],
+        )
+        await snapshot(
+            base_relationship_edges,
+            name="base_relationship_edges",
+            storage=storage,
+            formats=["parquet"],
+        )
+
     return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index 3fb865dce..86b4c49a1 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -58,8 +58,8 @@
             ],
             "nan_allowed_columns": [
                 "description",
-                "community",
-                "level"
+                "x",
+                "y"
             ],
             "subworkflows": 1,
             "max_runtime": 150,
diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json
index 812603dd8..cf4180857 100644
--- a/tests/fixtures/text/config.json
+++ b/tests/fixtures/text/config.json
@@ -76,8 +76,8 @@
             ],
             "nan_allowed_columns": [
                 "description",
-                "community",
-                "level"
+                "x",
+                "y"
             ],
             "subworkflows": 1,
             "max_runtime": 150,
diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
index a5336d24b..0403bee4d 100644
--- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
+++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
@@ -42,7 +42,7 @@ async def test_run_extract_entities_single_document_correct_entities_returned(se
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
 
     async def test_run_extract_entities_multiple_documents_correct_entities_returned(
@@ -81,7 +81,7 @@ async def test_run_extract_entities_multiple_documents_correct_entities_returned
         # self.assertItemsEqual isn't available yet, or I am just silly
         # so we sort the lists and compare them
         assert sorted(["TEST_ENTITY_1", "TEST_ENTITY_2", "TEST_ENTITY_3"]) == sorted([
-            entity["name"] for entity in results.entities
+            entity["title"] for entity in results.entities
         ])
 
     async def test_run_extract_entities_multiple_documents_correct_edges_returned(self):
diff --git a/tests/verbs/data/base_communities.parquet b/tests/verbs/data/base_communities.parquet
index d9c32028a..93edaa450 100644
Binary files a/tests/verbs/data/base_communities.parquet and b/tests/verbs/data/base_communities.parquet differ
diff --git a/tests/verbs/data/base_entity_nodes.parquet b/tests/verbs/data/base_entity_nodes.parquet
index 0240e7835..2b13b1db5 100644
Binary files a/tests/verbs/data/base_entity_nodes.parquet and b/tests/verbs/data/base_entity_nodes.parquet differ
diff --git a/tests/verbs/data/base_relationship_edges.parquet b/tests/verbs/data/base_relationship_edges.parquet
index 4f1e30255..4f2b4d837 100644
Binary files a/tests/verbs/data/base_relationship_edges.parquet and b/tests/verbs/data/base_relationship_edges.parquet differ
diff --git a/tests/verbs/data/create_base_text_units.parquet b/tests/verbs/data/create_base_text_units.parquet
index 4e1ae106f..2f8c7f903 100644
Binary files a/tests/verbs/data/create_base_text_units.parquet and b/tests/verbs/data/create_base_text_units.parquet differ
diff --git a/tests/verbs/data/create_final_communities.parquet b/tests/verbs/data/create_final_communities.parquet
index 6c170751c..dbb0dd8b2 100644
Binary files a/tests/verbs/data/create_final_communities.parquet and b/tests/verbs/data/create_final_communities.parquet differ
diff --git a/tests/verbs/data/create_final_community_reports.parquet b/tests/verbs/data/create_final_community_reports.parquet
index 5b29215f5..a04e16157 100644
Binary files a/tests/verbs/data/create_final_community_reports.parquet and b/tests/verbs/data/create_final_community_reports.parquet differ
diff --git a/tests/verbs/data/create_final_covariates.parquet b/tests/verbs/data/create_final_covariates.parquet
index 64b8928b4..798098cd1 100644
Binary files a/tests/verbs/data/create_final_covariates.parquet and b/tests/verbs/data/create_final_covariates.parquet differ
diff --git a/tests/verbs/data/create_final_documents.parquet b/tests/verbs/data/create_final_documents.parquet
index e90c863e7..33cb2ae3b 100644
Binary files a/tests/verbs/data/create_final_documents.parquet and b/tests/verbs/data/create_final_documents.parquet differ
diff --git a/tests/verbs/data/create_final_entities.parquet b/tests/verbs/data/create_final_entities.parquet
index 0938f7923..4c7cd2db4 100644
Binary files a/tests/verbs/data/create_final_entities.parquet and b/tests/verbs/data/create_final_entities.parquet differ
diff --git a/tests/verbs/data/create_final_nodes.parquet b/tests/verbs/data/create_final_nodes.parquet
index 5ac4fea26..5a3afdaa6 100644
Binary files a/tests/verbs/data/create_final_nodes.parquet and b/tests/verbs/data/create_final_nodes.parquet differ
diff --git a/tests/verbs/data/create_final_relationships.parquet b/tests/verbs/data/create_final_relationships.parquet
index 9b8b16c8c..911b3c2b4 100644
Binary files a/tests/verbs/data/create_final_relationships.parquet and b/tests/verbs/data/create_final_relationships.parquet differ
diff --git a/tests/verbs/data/create_final_text_units.parquet b/tests/verbs/data/create_final_text_units.parquet
index 885377929..f5766acb2 100644
Binary files a/tests/verbs/data/create_final_text_units.parquet and b/tests/verbs/data/create_final_text_units.parquet differ
diff --git a/tests/verbs/test_compute_communities.py b/tests/verbs/test_compute_communities.py
index 07db7d42c..5c91fc46b 100644
--- a/tests/verbs/test_compute_communities.py
+++ b/tests/verbs/test_compute_communities.py
@@ -4,7 +4,6 @@
 from graphrag.index.flows.compute_communities import (
     compute_communities,
 )
-from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.compute_communities import (
     workflow_name,
 )
@@ -16,37 +15,15 @@
 )
 
 
-async def test_compute_communities():
+def test_compute_communities():
     edges = load_test_table("base_relationship_edges")
     expected = load_test_table("base_communities")
 
-    context = create_run_context(None, None, None)
     config = get_config_for_workflow(workflow_name)
     clustering_strategy = config["cluster_graph"]["strategy"]
 
-    actual = await compute_communities(
-        edges, storage=context.storage, clustering_strategy=clustering_strategy
-    )
+    actual = compute_communities(edges, clustering_strategy=clustering_strategy)
 
     columns = list(expected.columns.values)
     compare_outputs(actual, expected, columns)
     assert len(actual.columns) == len(expected.columns)
-
-
-async def test_compute_communities_with_snapshots():
-    edges = load_test_table("base_relationship_edges")
-
-    context = create_run_context(None, None, None)
-    config = get_config_for_workflow(workflow_name)
-    clustering_strategy = config["cluster_graph"]["strategy"]
-
-    await compute_communities(
-        edges,
-        storage=context.storage,
-        clustering_strategy=clustering_strategy,
-        snapshot_transient_enabled=True,
-    )
-
-    assert context.storage.keys() == [
-        "base_communities.parquet",
-    ], "Community snapshot keys differ"
diff --git a/tests/verbs/test_create_final_relationships.py b/tests/verbs/test_create_final_relationships.py
index 3d3a4faf3..9f01e0830 100644
--- a/tests/verbs/test_create_final_relationships.py
+++ b/tests/verbs/test_create_final_relationships.py
@@ -16,10 +16,9 @@
 
 def test_create_final_relationships():
     edges = load_test_table("base_relationship_edges")
-    nodes = load_test_table("base_entity_nodes")
     expected = load_test_table(workflow_name)
 
-    actual = create_final_relationships(edges, nodes)
+    actual = create_final_relationships(edges)
 
     assert "id" in expected.columns
     columns = list(expected.columns.values)