Flow cleanup #1510

Merged · 16 commits · Dec 19, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241212190223784600.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Streamline flows."
+}
14 changes: 1 addition & 13 deletions graphrag/index/flows/compute_communities.py
@@ -9,15 +9,11 @@
 
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
-
     return base_communities
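With the storage handle and snapshot flag removed, compute_communities is now a plain synchronous transform. A minimal sketch of a caller, assuming an illustrative edge schema and strategy dict (neither is specified in this diff):

```python
import pandas as pd

from graphrag.index.flows.compute_communities import compute_communities

# Hypothetical edge list; real inputs come from the upstream extract_graph flow.
edges = pd.DataFrame({
    "source": ["a", "b", "c"],
    "target": ["b", "c", "a"],
    "weight": [1.0, 1.0, 1.0],
})

# No storage or snapshot_transient_enabled arguments anymore; transient
# snapshotting presumably becomes the caller's responsibility.
base_communities = compute_communities(
    edges,
    clustering_strategy={"type": "leiden"},  # assumed strategy shape
)
```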
20 changes: 2 additions & 18 deletions graphrag/index/flows/create_base_text_units.py
@@ -15,18 +15,14 @@
 )
 
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)
 
-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
 
 
 # TODO: would be nice to inline this completely in the main method with pandas
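The flow similarly loses its await/storage plumbing and now just filters out empty chunks in a single return expression. A toy illustration of that final notna() filter, with invented data:

```python
import pandas as pd

chunked = pd.DataFrame({"text": ["chunk one", None, "chunk two"]})

# Rows whose chunk text is missing are dropped and the index renumbered,
# mirroring the new single-expression return.
out = chunked[chunked["text"].notna()].reset_index(drop=True)
print(out)
#         text
# 0  chunk one
# 1  chunk two
```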
19 changes: 12 additions & 7 deletions graphrag/index/flows/create_final_nodes.py
@@ -10,6 +10,7 @@
     VerbCallbacks,
 )
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.embed_graph.embed_graph import embed_graph
 from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
-    )
-
-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
-
-    return joined.loc[
+
+    degrees = compute_degree(graph)
+
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
+    )
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
             "id",
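The behavioral detail worth noting is the how="left" merge against the degree table: isolated nodes have no degree row, surface as NaN, and are defaulted to 0, as the new comment says. A self-contained sketch of that pattern with made-up frames:

```python
import pandas as pd

nodes = pd.DataFrame({"title": ["a", "b", "isolated"]})
degrees = pd.DataFrame({"title": ["a", "b"], "degree": [2, 1]})

merged = nodes.merge(degrees, on="title", how="left")
# "isolated" matched no degree row, so its degree is NaN here...
merged["degree"] = merged["degree"].fillna(0).astype(int)
# ...and is now 0, keeping the column integer-typed.
```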
9 changes: 7 additions & 2 deletions graphrag/index/flows/create_final_relationships.py
@@ -5,20 +5,25 @@
 
 import pandas as pd
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph
 
 
 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
 
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
+
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",
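compute_edge_combined_degree itself is not shown in this diff; conceptually, an edge's combined degree is the sum of its endpoints' node degrees. A rough plain-pandas equivalent (illustrative only, not the library call):

```python
import pandas as pd

degrees = pd.DataFrame({"title": ["a", "b", "c"], "degree": [2, 3, 1]})
edges = pd.DataFrame({"source": ["a", "b"], "target": ["b", "c"]})

# Look up each endpoint's degree and sum the two.
lookup = degrees.set_index("title")["degree"]
edges["combined_degree"] = edges["source"].map(lookup) + edges["target"].map(lookup)
# a-b -> 2 + 3 = 5; b-c -> 3 + 1 = 4
```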
93 changes: 13 additions & 80 deletions graphrag/index/flows/extract_graph.py
@@ -6,41 +6,33 @@
 from typing import Any
 from uuid import uuid4
 
-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
     VerbCallbacks,
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )
 
-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
-
-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)
 
     return (base_entity_nodes, base_relationship_edges)
 
 
-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
        .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges
 
 
-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
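Unchanged but worth noting alongside the _prep_edges cleanup: every edge still gets a positional human_readable_id plus a random UUID id. A standalone demonstration of that pattern on a toy frame:

```python
from uuid import uuid4

import pandas as pd

edges = pd.DataFrame({"source": ["a", "b"], "target": ["b", "c"]})
edges["human_readable_id"] = edges.index  # 0, 1, ...
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
print(edges)
```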
3 changes: 1 addition & 2 deletions graphrag/index/flows/generate_text_embeddings.py
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
         strategy=text_embed_config["strategy"],
     )
 
-    data = data.loc[:, ["id", "embedding"]]
-
     if snapshot_embeddings_enabled is True:
+        data = data.loc[:, ["id", "embedding"]]
         await snapshot(
             data,
             name=f"embeddings.{name}",
15 changes: 15 additions & 0 deletions graphrag/index/operations/compute_degree.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing create_graph definition."""
+
+import networkx as nx
+import pandas as pd
+
+
+def compute_degree(graph: nx.Graph) -> pd.DataFrame:
+    """Create a new DataFrame with the degree of each node in the graph."""
+    return pd.DataFrame([
+        {"title": node, "degree": int(degree)}
+        for node, degree in graph.degree  # type: ignore
+    ])
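Usage of the new operation follows directly from the source above; this snippet should hold as written:

```python
import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

g = nx.Graph()
g.add_edges_from([("a", "b"), ("b", "c")])

print(compute_degree(g))
#   title  degree
# 0     a       1
# 1     b       2
# 2     c       1
```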
29 changes: 27 additions & 2 deletions graphrag/index/operations/extract_entities/extract_entities.py
@@ -37,7 +37,7 @@ async def extract_entities(
     async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types=DEFAULT_ENTITY_TYPES,
     num_threads: int = 4,
-) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """
     Extract entities from a piece of text.
 
@@ -138,7 +138,10 @@ async def run_strategy(row):
         entity_dfs.append(pd.DataFrame(result[0]))
         relationship_dfs.append(pd.DataFrame(result[1]))
 
-    return (entity_dfs, relationship_dfs)
+    entities = _merge_entities(entity_dfs)
+    relationships = _merge_relationships(relationship_dfs)
+
+    return (entities, relationships)
 
 
 def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
         case _:
             msg = f"Unknown strategy: {strategy_type}"
             raise ValueError(msg)
+
+
+def _merge_entities(entity_dfs) -> pd.DataFrame:
+    all_entities = pd.concat(entity_dfs, ignore_index=True)
+    return (
+        all_entities.groupby(["title", "type"], sort=False)
+        .agg(description=("description", list), text_unit_ids=("source_id", list))
+        .reset_index()
+    )
+
+
+def _merge_relationships(relationship_dfs) -> pd.DataFrame:
+    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
+    return (
+        all_relationships.groupby(["source", "target"], sort=False)
+        .agg(
+            description=("description", list),
+            text_unit_ids=("source_id", list),
+            weight=("weight", "sum"),
+        )
+        .reset_index()
+    )
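The merge helpers move here from extract_graph, now grouping on title and emitting text_unit_ids directly (the name -> title rename happens in the strategies, below), so downstream flows receive already-aggregated frames. A toy run of the entity merge with invented per-chunk results:

```python
import pandas as pd

# Two per-text-unit extraction results (invented for illustration).
df1 = pd.DataFrame({
    "title": ["ACME"], "type": ["ORG"],
    "description": ["a company"], "source_id": ["tu-1"],
})
df2 = pd.DataFrame({
    "title": ["ACME"], "type": ["ORG"],
    "description": ["a corporation"], "source_id": ["tu-2"],
})

merged = (
    pd.concat([df1, df2], ignore_index=True)
    .groupby(["title", "type"], sort=False)
    .agg(description=("description", list), text_unit_ids=("source_id", list))
    .reset_index()
)
# One row per (title, type): description == ["a company", "a corporation"],
# text_unit_ids == ["tu-1", "tu-2"]
```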
@@ -106,7 +106,7 @@ async def run_extract_entities(
     )
 
     entities = [
-        ({"name": item[0], **(item[1] or {})})
+        ({"title": item[0], **(item[1] or {})})
         for item in graph.nodes(data=True)
         if item is not None
     ]
@@ -58,7 +58,7 @@ async def run(  # noqa RUF029 async is required for interface
 
     return EntityExtractionResult(
         entities=[
-            {"type": entity_type, "name": name}
+            {"type": entity_type, "title": name}
            for name, entity_type in entity_map.items()
        ],
        relationships=[],