Commit

Merge branch 'main' into add-cosmosdb-to-storage

jgbradley1 authored Dec 19, 2024
2 parents cf165ac + c1c09ba commit b059367
Showing 86 changed files with 434 additions and 691 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241212190223784600.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Streamline flows."
+}
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241213181544864279.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Move extractor code to co-locate with operations."
+}
5 changes: 0 additions & 5 deletions graphrag/config/models/claim_extraction_config.py
@@ -37,12 +37,7 @@ class ClaimExtractionConfig(LLMConfig):
 
     def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved claim extraction strategy."""
-        from graphrag.index.operations.extract_covariates import (
-            ExtractClaimsStrategyType,
-        )
-
         return self.strategy or {
-            "type": ExtractClaimsStrategyType.graph_intelligence,
            "llm": self.llm.model_dump(),
            **self.parallelization.model_dump(),
            "extraction_prompt": (Path(root_dir) / self.prompt)
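With the "type" key no longer injected by the config model, the default strategy type is presumably resolved on the operation side instead. A minimal sketch of that pattern, assuming the enum now lives beside the extractor; load_strategy and the exact module layout are hypothetical, not shown in this diff:

from enum import Enum
from typing import Any


class ExtractClaimsStrategyType(str, Enum):
    """Assumed strategy enum, co-located with the extract_covariates operation."""

    graph_intelligence = "graph_intelligence"


def load_strategy(strategy: dict[str, Any]) -> ExtractClaimsStrategyType:
    # Fall back to graph_intelligence when the resolved config omits "type".
    return ExtractClaimsStrategyType(
        strategy.get("type", ExtractClaimsStrategyType.graph_intelligence)
    )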
2 changes: 1 addition & 1 deletion graphrag/config/models/embed_graph_config.py
@@ -36,7 +36,7 @@ class EmbedGraphConfig(BaseModel):
 
     def resolved_strategy(self) -> dict:
         """Get the resolved node2vec strategy."""
-        from graphrag.index.operations.embed_graph import (
+        from graphrag.index.operations.embed_graph.typing import (
             EmbedGraphStrategyType,
         )
 
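Importing the enum from embed_graph.typing rather than the package root lets the config model name a strategy without loading the embedding implementation (a common fix for circular or heavyweight imports). A sketch of what that leaf module presumably contains; the node2vec member is an assumption based on the docstring above:

from enum import Enum


class EmbedGraphStrategyType(str, Enum):
    """Assumed contents of graphrag/index/operations/embed_graph/typing.py."""

    node2vec = "node2vec"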
14 changes: 1 addition & 13 deletions graphrag/index/flows/compute_communities.py
@@ -9,15 +9,11 @@
 
 from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def compute_communities(
+def compute_communities(
     base_relationship_edges: pd.DataFrame,
-    storage: PipelineStorage,
     clustering_strategy: dict[str, Any],
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to create the base entity graph."""
     graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
     ).explode("title")
     base_communities["community"] = base_communities["community"].astype(int)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
-
     return base_communities
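compute_communities is now a synchronous, side-effect-free transform; the transient snapshot it used to write presumably moves up to the calling workflow. A hedged caller sketch under that assumption — run_compute_communities is a hypothetical name, while the snapshot call and import paths follow the lines deleted above:

from typing import Any

import pandas as pd

from graphrag.index.flows.compute_communities import compute_communities
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


async def run_compute_communities(
    base_relationship_edges: pd.DataFrame,
    storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    # The flow is now a pure DataFrame transform; persistence stays up here.
    base_communities = compute_communities(
        base_relationship_edges, clustering_strategy
    )

    if snapshot_transient_enabled:
        await snapshot(
            base_communities,
            name="base_communities",
            storage=storage,
            formats=["parquet"],
        )

    return base_communities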
20 changes: 2 additions & 18 deletions graphrag/index/flows/create_base_text_units.py
@@ -15,18 +15,14 @@
 )
 
 from graphrag.index.operations.chunk_text import chunk_text
-from graphrag.index.operations.snapshot import snapshot
 from graphrag.index.utils.hashing import gen_sha512_hash
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def create_base_text_units(
+def create_base_text_units(
     documents: pd.DataFrame,
     callbacks: VerbCallbacks,
-    storage: PipelineStorage,
     chunk_by_columns: list[str],
     chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
     # rename for downstream consumption
     chunked.rename(columns={"chunk": "text"}, inplace=True)
 
-    output = cast(
-        "pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
-    )
-
-    if snapshot_transient_enabled:
-        await snapshot(
-            output,
-            name="create_base_text_units",
-            storage=storage,
-            formats=["parquet"],
-        )
-
-    return output
+    return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))
 
 
 # TODO: would be nice to inline this completely in the main method with pandas
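create_base_text_units follows the same inversion: the flow drops its storage and snapshot parameters and becomes synchronous, and the caller is presumably responsible for persisting the transient output, along the lines of the workflow sketch shown after compute_communities above.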
12 changes: 6 additions & 6 deletions graphrag/index/flows/create_final_community_reports.py
@@ -12,7 +12,12 @@
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.graph.extractors.community_reports.schemas import (
+from graphrag.index.operations.summarize_communities import (
+    prepare_community_reports,
+    restore_community_hierarchy,
+    summarize_communities,
+)
+from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import (
     CLAIM_DESCRIPTION,
     CLAIM_DETAILS,
     CLAIM_ID,
@@ -32,11 +37,6 @@
     NODE_ID,
     NODE_NAME,
 )
-from graphrag.index.operations.summarize_communities import (
-    prepare_community_reports,
-    restore_community_hierarchy,
-    summarize_communities,
-)
 
 
 async def create_final_community_reports(
2 changes: 1 addition & 1 deletion graphrag/index/flows/create_final_covariates.py
@@ -13,7 +13,7 @@
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.extract_covariates import (
+from graphrag.index.operations.extract_covariates.extract_covariates import (
     extract_covariates,
 )
 
23 changes: 14 additions & 9 deletions graphrag/index/flows/create_final_nodes.py
@@ -10,9 +10,10 @@
     VerbCallbacks,
 )
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.create_graph import create_graph
-from graphrag.index.operations.embed_graph import embed_graph
-from graphrag.index.operations.layout_graph import layout_graph
+from graphrag.index.operations.embed_graph.embed_graph import embed_graph
+from graphrag.index.operations.layout_graph.layout_graph import layout_graph
 
 
 def create_final_nodes(
@@ -37,15 +38,19 @@ def create_final_nodes(
         layout_strategy,
         embeddings=graph_embeddings,
     )
-    nodes = base_entity_nodes.merge(
-        layout, left_on="title", right_on="label", how="left"
-    )
 
-    joined = nodes.merge(base_communities, on="title", how="left")
-    joined["level"] = joined["level"].fillna(0).astype(int)
-    joined["community"] = joined["community"].fillna(-1).astype(int)
+    degrees = compute_degree(graph)
 
-    return joined.loc[
+    nodes = (
+        base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
+        .merge(degrees, on="title", how="left")
+        .merge(base_communities, on="title", how="left")
+    )
+    nodes["level"] = nodes["level"].fillna(0).astype(int)
+    nodes["community"] = nodes["community"].fillna(-1).astype(int)
+    # disconnected nodes and those with no community even at level 0 can be missing degree
+    nodes["degree"] = nodes["degree"].fillna(0).astype(int)
+    return nodes.loc[
         :,
         [
             "id",
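The new compute_degree operation is not itself shown in this diff, but the _compute_degree helper deleted from extract_graph.py below suggests its shape; since create_final_nodes merges its result on "title", a reasonable sketch (an assumption, with the node column renamed from the deleted helper's "name") is:

import networkx as nx
import pandas as pd


def compute_degree(graph: nx.Graph) -> pd.DataFrame:
    """Assumed sketch: one row per node, keyed by title, with its degree."""
    return pd.DataFrame([
        {"title": node, "degree": int(degree)}
        for node, degree in graph.degree  # type: ignore
    ])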
9 changes: 7 additions & 2 deletions graphrag/index/flows/create_final_relationships.py
@@ -5,20 +5,25 @@
 
 import pandas as pd
 
+from graphrag.index.operations.compute_degree import compute_degree
 from graphrag.index.operations.compute_edge_combined_degree import (
     compute_edge_combined_degree,
 )
+from graphrag.index.operations.create_graph import create_graph
 
 
 def create_final_relationships(
     base_relationship_edges: pd.DataFrame,
-    base_entity_nodes: pd.DataFrame,
 ) -> pd.DataFrame:
     """All the steps to transform final relationships."""
     relationships = base_relationship_edges
 
+    graph = create_graph(base_relationship_edges)
+    degrees = compute_degree(graph)
+
     relationships["combined_degree"] = compute_edge_combined_degree(
         relationships,
-        base_entity_nodes,
+        degrees,
         node_name_column="title",
         node_degree_column="degree",
         edge_source_column="source",
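create_final_relationships now derives node degrees from the graph itself instead of receiving base_entity_nodes. As the column arguments suggest, an edge's combined degree is the sum of its endpoints' node degrees; a hedged sketch of the idea (combined_degree is a hypothetical helper, not the library's implementation):

import pandas as pd


def combined_degree(edges: pd.DataFrame, degrees: pd.DataFrame) -> pd.Series:
    # Look up each endpoint's degree by node title and sum the two.
    degree_by_title = degrees.set_index("title")["degree"]
    return (
        edges["source"].map(degree_by_title).fillna(0)
        + edges["target"].map(degree_by_title).fillna(0)
    ).astype(int)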
93 changes: 13 additions & 80 deletions graphrag/index/flows/extract_graph.py
@@ -6,41 +6,33 @@
 from typing import Any
 from uuid import uuid4
 
-import networkx as nx
 import pandas as pd
 from datashaper import (
     AsyncType,
     VerbCallbacks,
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
-from graphrag.index.operations.snapshot import snapshot
-from graphrag.index.operations.snapshot_graphml import snapshot_graphml
 from graphrag.index.operations.summarize_descriptions import (
     summarize_descriptions,
 )
-from graphrag.storage.pipeline_storage import PipelineStorage
 
 
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
-    storage: PipelineStorage,
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
     entity_types: list[str] | None = None,
     summarization_strategy: dict[str, Any] | None = None,
     summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
-    entity_dfs, relationship_dfs = await extract_entities(
+    entities, relationships = await extract_entities(
         text_units,
         callbacks,
         cache,
@@ -52,87 +44,38 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
-    if not _validate_data(entity_dfs):
+    if not _validate_data(entities):
         error_msg = "Entity Extraction failed. No entities detected during extraction."
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    if not _validate_data(relationship_dfs):
+    if not _validate_data(relationships):
         error_msg = (
             "Entity Extraction failed. No relationships detected during extraction."
         )
         callbacks.error(error_msg)
         raise ValueError(error_msg)
 
-    merged_entities = _merge_entities(entity_dfs)
-    merged_relationships = _merge_relationships(relationship_dfs)
-
     entity_summaries, relationship_summaries = await summarize_descriptions(
-        merged_entities,
-        merged_relationships,
+        entities,
+        relationships,
         callbacks,
         cache,
         strategy=summarization_strategy,
         num_threads=summarization_num_threads,
     )
 
-    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)
-
-    graph = create_graph(base_relationship_edges)
-
-    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
-
-    if snapshot_graphml_enabled:
-        # todo: extract graphs at each level, and add in meta like descriptions
-        await snapshot_graphml(
-            graph,
-            name="graph",
-            storage=storage,
-        )
+    base_relationship_edges = _prep_edges(relationships, relationship_summaries)
 
-    if snapshot_transient_enabled:
-        await snapshot(
-            base_entity_nodes,
-            name="base_entity_nodes",
-            storage=storage,
-            formats=["parquet"],
-        )
-        await snapshot(
-            base_relationship_edges,
-            name="base_relationship_edges",
-            storage=storage,
-            formats=["parquet"],
-        )
+    base_entity_nodes = _prep_nodes(entities, entity_summaries)
 
     return (base_entity_nodes, base_relationship_edges)
 
 
-def _merge_entities(entity_dfs) -> pd.DataFrame:
-    all_entities = pd.concat(entity_dfs, ignore_index=True)
-    return (
-        all_entities.groupby(["name", "type"], sort=False)
-        .agg({"description": list, "source_id": list})
-        .reset_index()
-    )
-
-
-def _merge_relationships(relationship_dfs) -> pd.DataFrame:
-    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
-    return (
-        all_relationships.groupby(["source", "target"], sort=False)
-        .agg({"description": list, "source_id": list, "weight": "sum"})
-        .reset_index()
-    )
-
-
-def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
-    degrees_df = _compute_degree(graph)
+def _prep_nodes(entities, summaries) -> pd.DataFrame:
     entities.drop(columns=["description"], inplace=True)
-    nodes = (
-        entities.merge(summaries, on="name", how="left")
-        .merge(degrees_df, on="name")
-        .drop_duplicates(subset="name")
-        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
+    nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
+        subset="title"
     )
     nodes = nodes.loc[nodes["title"].notna()].reset_index()
     nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
         relationships.drop(columns=["description"])
         .drop_duplicates(subset=["source", "target"])
        .merge(summaries, on=["source", "target"], how="left")
-        .rename(columns={"source_id": "text_unit_ids"})
     )
     edges["human_readable_id"] = edges.index
     edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
     return edges
 
 
-def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
-    return pd.DataFrame([
-        {"name": node, "degree": int(degree)}
-        for node, degree in graph.degree  # type: ignore
-    ])
-
-
-def _validate_data(df_list: list[pd.DataFrame]) -> bool:
-    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
-    return any(
-        len(df) > 0 for df in df_list
-    )  # Check for len, not .empty, as the dfs have schemas in some cases
+def _validate_data(df: pd.DataFrame) -> bool:
+    """Validate that the dataframe has data."""
+    return len(df) > 0
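extract_entities now returns a single pre-merged entity frame and relationship frame, which is why the _merge_entities/_merge_relationships helpers disappear from this flow and _validate_data takes one DataFrame instead of a list. Assuming the grouping logic moved behind extract_entities largely intact (keyed on "title" rather than "name", since _prep_nodes now merges on "title"), it presumably still looks something like this sketch; the standalone function names are hypothetical:

import pandas as pd


def merge_entities(entity_dfs: list[pd.DataFrame]) -> pd.DataFrame:
    # Collapse per-text-unit extractions into one row per (title, type).
    all_entities = pd.concat(entity_dfs, ignore_index=True)
    return (
        all_entities.groupby(["title", "type"], sort=False)
        .agg({"description": list, "source_id": list})
        .reset_index()
    )


def merge_relationships(relationship_dfs: list[pd.DataFrame]) -> pd.DataFrame:
    # One row per (source, target), summing edge weights across text units.
    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
    return (
        all_relationships.groupby(["source", "target"], sort=False)
        .agg({"description": list, "source_id": list, "weight": "sum"})
        .reset_index()
    )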
3 changes: 1 addition & 2 deletions graphrag/index/flows/generate_text_embeddings.py
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
         strategy=text_embed_config["strategy"],
     )
 
-    data = data.loc[:, ["id", "embedding"]]
-
     if snapshot_embeddings_enabled is True:
+        data = data.loc[:, ["id", "embedding"]]
         await snapshot(
             data,
             name=f"embeddings.{name}",
4 changes: 0 additions & 4 deletions graphrag/index/graph/__init__.py

This file was deleted.

8 changes: 0 additions & 8 deletions graphrag/index/graph/embedding/__init__.py

This file was deleted.

(The remaining changed files in this commit did not load and are not shown here.)
