diff --git a/nx_cugraph/classes/digraph.py b/nx_cugraph/classes/digraph.py index a73ca8d0b..9ac75a087 100644 --- a/nx_cugraph/classes/digraph.py +++ b/nx_cugraph/classes/digraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -26,7 +26,7 @@ import nx_cugraph as nxcg from ..utils import index_dtype -from .graph import CudaGraph, Graph +from .graph import CudaGraph, Graph, _GraphCache if TYPE_CHECKING: # pragma: no cover from nx_cugraph.typing import AttrKey @@ -106,6 +106,10 @@ def to_cudagraph_class(cls) -> type[CudaDiGraph]: def to_networkx_class(cls) -> type[nx.DiGraph]: return nx.DiGraph + def __init__(self, incoming_graph_data=None, **attr): + super().__init__(incoming_graph_data, **attr) + self.__networkx_cache__ = _GraphCache(self) + ########################## # Networkx graph methods # ########################## diff --git a/nx_cugraph/classes/multidigraph.py b/nx_cugraph/classes/multidigraph.py index 0671d21a8..3ec99b352 100644 --- a/nx_cugraph/classes/multidigraph.py +++ b/nx_cugraph/classes/multidigraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Copyright (c) 2023-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,7 +17,7 @@ import nx_cugraph as nxcg from .digraph import CudaDiGraph, DiGraph -from .graph import Graph +from .graph import Graph, _GraphCache from .multigraph import CudaMultiGraph, MultiGraph __all__ = ["CudaMultiDiGraph", "MultiDiGraph"] @@ -51,6 +51,10 @@ def to_cudagraph_class(cls) -> type[CudaMultiDiGraph]: def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph + def __init__(self, incoming_graph_data=None, multigraph_input=None, **attr): + super().__init__(incoming_graph_data, multigraph_input, **attr) + self.__networkx_cache__ = _GraphCache(self) + ########################## # Networkx graph methods # ########################## diff --git a/nx_cugraph/scripts/print_tree.py b/nx_cugraph/scripts/print_tree.py index fbb1c3dd0..f93702e24 100755 --- a/nx_cugraph/scripts/print_tree.py +++ b/nx_cugraph/scripts/print_tree.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -143,7 +143,8 @@ def create_tree( incomplete=incomplete, different=different, ) - assoc_in(tree, path.split("."), payload) + if payload is not None: + assoc_in(tree, path.split("."), payload) return tree diff --git a/nx_cugraph/tests/test_classes.py b/nx_cugraph/tests/test_classes.py index 0ac238b35..fea41ac5e 100644 --- a/nx_cugraph/tests/test_classes.py +++ b/nx_cugraph/tests/test_classes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,7 +10,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import nx_cugraph as nxcg +from nx_cugraph.classes.graph import _GraphCache def test_class_to_class(): @@ -75,3 +78,11 @@ def test_class_to_class(): assert val.to_cudagraph_class() is cls assert cls.is_directed() == G.is_directed() == val.is_directed() assert cls.is_multigraph() == G.is_multigraph() == val.is_multigraph() + + +@pytest.mark.parametrize( + "graph_class", [nxcg.Graph, nxcg.DiGraph, nxcg.MultiGraph, nxcg.MultiDiGraph] +) +def test_cache_type(graph_class): + G = graph_class() + assert isinstance(G.__networkx_cache__, _GraphCache) diff --git a/scripts/update_readme.py b/scripts/update_readme.py index 0dad5d675..2b04c7b08 100755 --- a/scripts/update_readme.py +++ b/scripts/update_readme.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -111,10 +111,12 @@ def get_payload(info, **kwargs): path = "networkx." + info.networkx_path subpath, name = path.rsplit(".", 1) # Many objects are referred to in modules above where they are defined. - while subpath: + while True: path = f"{subpath}.{name}" if path in doc_urls: return f'{name}' + if subpath == "networkx": + break subpath = subpath.rsplit(".", 1)[0] warn(f"Unable to find URL for {name!r}: {path}", stacklevel=0) return name