Skip to content

Commit

Permalink
Fix (Multi)DiGraph.__networkx_cache__ initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Feb 14, 2025
1 parent 7c9c868 commit a2c495c
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 9 deletions.
8 changes: 6 additions & 2 deletions nx_cugraph/classes/digraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -26,7 +26,7 @@
import nx_cugraph as nxcg

from ..utils import index_dtype
from .graph import CudaGraph, Graph
from .graph import CudaGraph, Graph, _GraphCache

if TYPE_CHECKING: # pragma: no cover
from nx_cugraph.typing import AttrKey
Expand Down Expand Up @@ -106,6 +106,10 @@ def to_cudagraph_class(cls) -> type[CudaDiGraph]:
def to_networkx_class(cls) -> type[nx.DiGraph]:
return nx.DiGraph

def __init__(self, incoming_graph_data=None, **attr):
super().__init__(incoming_graph_data, **attr)
self.__networkx_cache__ = _GraphCache(self)

##########################
# Networkx graph methods #
##########################
Expand Down
8 changes: 6 additions & 2 deletions nx_cugraph/classes/multidigraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Copyright (c) 2023-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -17,7 +17,7 @@
import nx_cugraph as nxcg

from .digraph import CudaDiGraph, DiGraph
from .graph import Graph
from .graph import Graph, _GraphCache
from .multigraph import CudaMultiGraph, MultiGraph

__all__ = ["CudaMultiDiGraph", "MultiDiGraph"]
Expand Down Expand Up @@ -51,6 +51,10 @@ def to_cudagraph_class(cls) -> type[CudaMultiDiGraph]:
def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
return nx.MultiDiGraph

def __init__(self, incoming_graph_data=None, multigraph_input=None, **attr):
super().__init__(incoming_graph_data, multigraph_input, **attr)
self.__networkx_cache__ = _GraphCache(self)

##########################
# Networkx graph methods #
##########################
Expand Down
5 changes: 3 additions & 2 deletions nx_cugraph/scripts/print_tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -143,7 +143,8 @@ def create_tree(
incomplete=incomplete,
different=different,
)
assoc_in(tree, path.split("."), payload)
if payload is not None:
assoc_in(tree, path.split("."), payload)
return tree


Expand Down
13 changes: 12 additions & 1 deletion nx_cugraph/tests/test_classes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -10,7 +10,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import nx_cugraph as nxcg
from nx_cugraph.classes.graph import _GraphCache


def test_class_to_class():
Expand Down Expand Up @@ -75,3 +78,11 @@ def test_class_to_class():
assert val.to_cudagraph_class() is cls
assert cls.is_directed() == G.is_directed() == val.is_directed()
assert cls.is_multigraph() == G.is_multigraph() == val.is_multigraph()


@pytest.mark.parametrize(
"graph_class", [nxcg.Graph, nxcg.DiGraph, nxcg.MultiGraph, nxcg.MultiDiGraph]
)
def test_cache_type(graph_class):
G = graph_class()
assert isinstance(G.__networkx_cache__, _GraphCache)
6 changes: 4 additions & 2 deletions scripts/update_readme.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION.
# Copyright (c) 2024-2025, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -111,10 +111,12 @@ def get_payload(info, **kwargs):
path = "networkx." + info.networkx_path
subpath, name = path.rsplit(".", 1)
# Many objects are referred to in modules above where they are defined.
while subpath:
while True:
path = f"{subpath}.{name}"
if path in doc_urls:
return f'<a href="{doc_urls[path]}">{name}</a>'
if subpath == "networkx":
break
subpath = subpath.rsplit(".", 1)[0]
warn(f"Unable to find URL for {name!r}: {path}", stacklevel=0)
return name
Expand Down

0 comments on commit a2c495c

Please sign in to comment.