From c006da2fdd5a80637aca0c1bd1f7921ef954cce1 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Thu, 13 Feb 2025 21:06:25 -0600 Subject: [PATCH] Dispatch for e.g. `nx.Graph(backend="cugraph")` --- .pre-commit-config.yaml | 5 ----- _nx_cugraph/__init__.py | 6 ++++++ nx_cugraph/classes/digraph.py | 22 +++++++++++++++++++--- nx_cugraph/classes/graph.py | 14 +++++++++++++- nx_cugraph/classes/multidigraph.py | 21 +++++++++++++++++++-- nx_cugraph/classes/multigraph.py | 16 ++++++++++++++-- scripts/update_readme.py | 5 ++++- 7 files changed, 75 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 330c58ae2..040322e34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,11 +65,6 @@ repos: - flake8==7.1.1 - flake8-bugbear==24.12.12 - flake8-simplify==0.21.0 - - repo: https://github.com/asottile/yesqa - rev: v1.5.0 - hooks: - - id: yesqa - additional_dependencies: *flake8_dependencies - repo: https://github.com/codespell-project/codespell rev: v2.4.1 hooks: diff --git a/_nx_cugraph/__init__.py b/_nx_cugraph/__init__.py index c83b489f2..7f3e0824d 100644 --- a/_nx_cugraph/__init__.py +++ b/_nx_cugraph/__init__.py @@ -81,6 +81,7 @@ "descendants", "descendants_at_distance", "diamond_graph", + "digraph__new__", "dijkstra_path", "dijkstra_path_length", "dodecahedral_graph", @@ -94,6 +95,7 @@ "from_scipy_sparse_array", "frucht_graph", "generic_bfs_edges", + "graph__new__", "has_path", "heawood_graph", "hits", @@ -122,6 +124,8 @@ "louvain_communities", "lowest_common_ancestor", "moebius_kantor_graph", + "multidigraph__new__", + "multigraph__new__", "node_connected_component", "null_graph", "number_connected_components", @@ -340,6 +344,8 @@ def update_env_var(varname): update_env_var("NETWORKX_AUTOMATIC_BACKENDS") # For NetworkX 3.2 # Automatically create nx-cugraph Graph from graph generators update_env_var("NETWORKX_BACKEND_PRIORITY_GENERATORS") + # And for graph classes such as `nx.Graph()` for NetworkX >=3.5 + update_env_var("NETWORKX_BACKEND_PRIORITY_CLASSES") # Run default NetworkX implementation (in >=3.4) if not implemented by nx-cugraph if (varname := "NETWORKX_FALLBACK_TO_NX") not in os.environ: os.environ[varname] = "true" diff --git a/nx_cugraph/classes/digraph.py b/nx_cugraph/classes/digraph.py index a73ca8d0b..24916d7e7 100644 --- a/nx_cugraph/classes/digraph.py +++ b/nx_cugraph/classes/digraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -25,8 +25,8 @@ import nx_cugraph as nxcg -from ..utils import index_dtype -from .graph import CudaGraph, Graph +from ..utils import index_dtype, networkx_algorithm +from .graph import CudaGraph, Graph, _GraphCache if TYPE_CHECKING: # pragma: no cover from nx_cugraph.typing import AttrKey @@ -106,6 +106,22 @@ def to_cudagraph_class(cls) -> type[CudaDiGraph]: def to_networkx_class(cls) -> type[nx.DiGraph]: return nx.DiGraph + @networkx_algorithm(name="digraph__new__", version_added="25.04") + def __new__(cls, incoming_graph_data=None, **attr): + return object.__new__(DiGraph) + + @__new__._can_run + def _(cls, incoming_graph_data=None, **attr): # noqa: N805 + if cls not in {nx.DiGraph, DiGraph}: + return "Unknown subclasses of nx.DiGraph are not supported." + return True + + del _ + + def __init__(self, incoming_graph_data=None, **attr): + super().__init__(incoming_graph_data, **attr) + self.__networkx_cache__ = _GraphCache(self) + ########################## # Networkx graph methods # ########################## diff --git a/nx_cugraph/classes/graph.py b/nx_cugraph/classes/graph.py index f6da5bbce..6a01668b0 100644 --- a/nx_cugraph/classes/graph.py +++ b/nx_cugraph/classes/graph.py @@ -28,7 +28,7 @@ import nx_cugraph as nxcg from nx_cugraph import _nxver -from ..utils import index_dtype +from ..utils import index_dtype, networkx_algorithm if TYPE_CHECKING: # pragma: no cover from collections.abc import Iterable, Iterator @@ -291,6 +291,18 @@ def to_networkx_class(cls) -> type[nx.Graph]: def to_undirected_class(cls) -> type[Graph]: return Graph + @networkx_algorithm(name="graph__new__", version_added="25.04") + def __new__(cls, incoming_graph_data=None, **attr): + return object.__new__(Graph) + + @__new__._can_run + def _(cls, incoming_graph_data=None, **attr): # noqa: N805 + if cls not in {nx.Graph, Graph}: + return "Unknown subclasses of nx.Graph are not supported." + return True + + del _ + def __init__(self, incoming_graph_data=None, **attr): super().__init__(incoming_graph_data, **attr) self.__networkx_cache__ = _GraphCache(self) diff --git a/nx_cugraph/classes/multidigraph.py b/nx_cugraph/classes/multidigraph.py index 0671d21a8..644f4b90f 100644 --- a/nx_cugraph/classes/multidigraph.py +++ b/nx_cugraph/classes/multidigraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -16,8 +16,9 @@ import nx_cugraph as nxcg +from ..utils import networkx_algorithm from .digraph import CudaDiGraph, DiGraph -from .graph import Graph +from .graph import Graph, _GraphCache from .multigraph import CudaMultiGraph, MultiGraph __all__ = ["CudaMultiDiGraph", "MultiDiGraph"] @@ -51,6 +52,22 @@ def to_cudagraph_class(cls) -> type[CudaMultiDiGraph]: def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph + @networkx_algorithm(name="multidigraph__new__", version_added="25.04") + def __new__(cls, incoming_graph_data=None, multigraph_input=None, **attr): + return object.__new__(MultiDiGraph) + + @__new__._can_run + def _(cls, incoming_graph_data=None, multigraph_input=None, **attr): # noqa: N805 + if cls not in {nx.MultiDiGraph, MultiDiGraph}: + return "Unknown subclasses of nx.MultiDiGraph are not supported." + return True + + del _ + + def __init__(self, incoming_graph_data=None, multigraph_input=None, **attr): + super().__init__(incoming_graph_data, multigraph_input, **attr) + self.__networkx_cache__ = _GraphCache(self) + ########################## # Networkx graph methods # ########################## diff --git a/nx_cugraph/classes/multigraph.py b/nx_cugraph/classes/multigraph.py index 3f0204f69..da0710281 100644 --- a/nx_cugraph/classes/multigraph.py +++ b/nx_cugraph/classes/multigraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -21,7 +21,7 @@ import nx_cugraph as nxcg -from ..utils import index_dtype +from ..utils import index_dtype, networkx_algorithm from .graph import CudaGraph, Graph, _GraphCache if TYPE_CHECKING: @@ -73,6 +73,18 @@ def to_networkx_class(cls) -> type[nx.MultiGraph]: def to_undirected_class(cls) -> type[MultiGraph]: return MultiGraph + @networkx_algorithm(name="multigraph__new__", version_added="25.04") + def __new__(cls, incoming_graph_data=None, multigraph_input=None, **attr): + return object.__new__(MultiGraph) + + @__new__._can_run + def _(cls, incoming_graph_data=None, multigraph_input=None, **attr): # noqa: N805 + if cls not in {nx.MultiGraph, MultiGraph}: + return "Unknown subclasses of nx.MultiGraph are not supported." + return True + + del _ + def __init__(self, incoming_graph_data=None, multigraph_input=None, **attr): super().__init__(incoming_graph_data, multigraph_input, **attr) self.__networkx_cache__ = _GraphCache(self) diff --git a/scripts/update_readme.py b/scripts/update_readme.py index 0dad5d675..26face0d9 100755 --- a/scripts/update_readme.py +++ b/scripts/update_readme.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -110,6 +110,9 @@ def main(readme_file, objects_filename): def get_payload(info, **kwargs): path = "networkx." + info.networkx_path subpath, name = path.rsplit(".", 1) + if "__" in name: + # Don't include e.g. Graph.__new__ + return None # Many objects are referred to in modules above where they are defined. while subpath: path = f"{subpath}.{name}"