Skip to content

Commit

Permalink
Simplify layout_graph config
Browse files: browse the repository at this point in the history
  • Loading branch information
natoverse committed Dec 24, 2024
1 parent 2abd5b4 commit be3440f
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 76 deletions.
2 changes: 1 addition & 1 deletion graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
enabled: false # if true, will generate node2vec embeddings for nodes
umap:
enabled: false # if true, will generate UMAP embeddings for nodes
enabled: false # if true, will generate UMAP embeddings for nodes (embed_graph must also be enabled)
snapshots:
graphml: false
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
PipelineWorkflowReference(
name=create_final_nodes,
config={
"layout_graph_enabled": settings.umap.enabled,
"layout_enabled": settings.umap.enabled,
"embed_graph": settings.embed_graph,
},
),
Expand Down
6 changes: 2 additions & 4 deletions graphrag/index/flows/create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

"""All the steps to transform final nodes."""

from typing import Any

import pandas as pd
from datashaper import (
VerbCallbacks,
Expand All @@ -23,7 +21,7 @@ def create_final_nodes(
base_communities: pd.DataFrame,
callbacks: VerbCallbacks,
embed_config: EmbedGraphConfig,
layout_strategy: dict[str, Any],
layout_enabled: bool,
) -> pd.DataFrame:
"""All the steps to transform final nodes."""
graph = create_graph(base_relationship_edges)
Expand All @@ -36,7 +34,7 @@ def create_final_nodes(
layout = layout_graph(
graph,
callbacks,
layout_strategy,
layout_enabled,
embeddings=graph_embeddings,
)

Expand Down
66 changes: 20 additions & 46 deletions graphrag/index/operations/layout_graph/layout_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

"""A module containing layout_graph, _run_layout and _apply_layout_to_graph methods definition."""

from enum import Enum
from typing import Any

import networkx as nx
import pandas as pd
from datashaper import VerbCallbacks
Expand All @@ -14,21 +11,10 @@
from graphrag.index.operations.layout_graph.typing import GraphLayout


class LayoutGraphStrategyType(str, Enum):
"""LayoutGraphStrategyType class definition."""

umap = "umap"
zero = "zero"

def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'


def layout_graph(
graph: nx.Graph,
callbacks: VerbCallbacks,
strategy: dict[str, Any],
enabled: bool,
embeddings: NodeEmbeddings | None,
):
"""
Expand All @@ -54,14 +40,10 @@ def layout_graph(
min_dist: 0.75 # Optional, The min distance to use for the umap algorithm, default: 0.75
```
"""
strategy_type = strategy.get("type", LayoutGraphStrategyType.umap)
strategy_args = {**strategy}

layout = _run_layout(
strategy_type,
graph,
enabled,
embeddings if embeddings is not None else {},
strategy_args,
callbacks,
)

Expand All @@ -73,34 +55,26 @@ def layout_graph(


def _run_layout(
strategy: LayoutGraphStrategyType,
graph: nx.Graph,
enabled: bool,
embeddings: NodeEmbeddings,
args: dict[str, Any],
callbacks: VerbCallbacks,
) -> GraphLayout:
match strategy:
case LayoutGraphStrategyType.umap:
from graphrag.index.operations.layout_graph.umap import (
run as run_umap,
)

return run_umap(
graph,
embeddings,
args,
lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d),
)
case LayoutGraphStrategyType.zero:
from graphrag.index.operations.layout_graph.zero import (
run as run_zero,
)
if enabled:
from graphrag.index.operations.layout_graph.umap import (
run as run_umap,
)

return run_umap(
graph,
embeddings,
lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d),
)
from graphrag.index.operations.layout_graph.zero import (
run as run_zero,
)

return run_zero(
graph,
args,
lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d),
)
case _:
msg = f"Unknown strategy {strategy}"
raise ValueError(msg)
return run_zero(
graph,
lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d),
)
6 changes: 1 addition & 5 deletions graphrag/index/operations/layout_graph/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import logging
import traceback
from typing import Any

import networkx as nx
import numpy as np
Expand All @@ -27,7 +26,6 @@
def run(
graph: nx.Graph,
embeddings: NodeEmbeddings,
args: dict[str, Any],
on_error: ErrorHandlerFn,
) -> GraphLayout:
"""Run method definition."""
Expand Down Expand Up @@ -56,8 +54,6 @@ def run(
embedding_vectors=np.array(embedding_vectors),
node_labels=nodes,
**additional_args,
min_dist=args.get("min_dist", 0.75),
n_neighbors=args.get("n_neighbors", 5),
)
except Exception as e:
log.exception("Error running UMAP")
Expand Down Expand Up @@ -87,7 +83,7 @@ def compute_umap_positions(
node_categories: list[int] | None = None,
node_sizes: list[int] | None = None,
min_dist: float = 0.75,
n_neighbors: int = 25,
n_neighbors: int = 5,
spread: int = 1,
metric: str = "euclidean",
n_components: int = 2,
Expand Down
2 changes: 0 additions & 2 deletions graphrag/index/operations/layout_graph/zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import logging
import traceback
from typing import Any

import networkx as nx

Expand All @@ -24,7 +23,6 @@

def run(
graph: nx.Graph,
_args: dict[str, Any],
on_error: ErrorHandlerFn,
) -> GraphLayout:
"""Run method definition."""
Expand Down
22 changes: 6 additions & 16 deletions graphrag/index/workflows/v1/create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""A module containing build_steps method definition."""

from typing import Any, cast
from typing import cast

from datashaper import (
Table,
Expand Down Expand Up @@ -31,23 +31,13 @@ def build_steps(
## Dependencies
* `workflow:extract_graph`
"""
layout_graph_enabled = config.get("layout_graph_enabled", True)
layout_graph_config = config.get(
"layout_graph",
{
"strategy": {
"type": "umap" if layout_graph_enabled else "zero",
},
},
)
layout_strategy = layout_graph_config.get("strategy")

embed_config = cast("EmbedGraphConfig", config.get("embed_graph"))
layout_enabled = config["layout_enabled"]
embed_config = cast("EmbedGraphConfig", config["embed_graph"])

return [
{
"verb": workflow_name,
"args": {"layout_strategy": layout_strategy, "embed_config": embed_config},
"args": {"layout_enabled": layout_enabled, "embed_config": embed_config},
"input": {
"source": "workflow:extract_graph",
"communities": "workflow:compute_communities",
Expand All @@ -61,7 +51,7 @@ async def workflow(
callbacks: VerbCallbacks,
runtime_storage: PipelineStorage,
embed_config: EmbedGraphConfig,
layout_strategy: dict[str, Any],
layout_enabled: bool,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform final nodes."""
Expand All @@ -75,7 +65,7 @@ async def workflow(
base_communities,
callbacks,
embed_config=embed_config,
layout_strategy=layout_strategy,
layout_enabled=layout_enabled,
)

return create_verb_result(
Expand Down
2 changes: 1 addition & 1 deletion tests/verbs/test_create_final_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_create_final_nodes():
base_communities=base_communities,
callbacks=NoopVerbCallbacks(),
embed_config=embed_config,
layout_strategy={"type": "zero"},
layout_enabled=False,
)

assert "id" in expected.columns
Expand Down

0 comments on commit be3440f

Please sign in to comment.