Skip to content

Commit

Permalink
[CUDA] Make sanity check optional for dgl.create_block. (#7240)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Mar 26, 2024
1 parent 3c39153 commit 7815fe8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 8 deletions.
17 changes: 14 additions & 3 deletions python/dgl/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,12 @@ def heterograph(data_dict, num_nodes_dict=None, idtype=None, device=None):


def create_block(
data_dict, num_src_nodes=None, num_dst_nodes=None, idtype=None, device=None
data_dict,
num_src_nodes=None,
num_dst_nodes=None,
idtype=None,
device=None,
node_count_check=True,
):
"""Create a message flow graph (MFG) as a :class:`DGLBlock` object.
Expand Down Expand Up @@ -456,6 +461,9 @@ def create_block(
the :attr:`data` argument. If :attr:`data` is not a tuple of node-tensors, the
returned graph is on CPU. If the specified :attr:`device` differs from that of the
provided tensors, it casts the given tensors to the specified device first.
node_count_check : bool, optional
When num_src_nodes and num_dst_nodes are passed, whether we should perform
sanity checks to ensure they are valid.
Returns
-------
Expand Down Expand Up @@ -540,13 +548,16 @@ def create_block(
node_tensor_dict = {}
for (sty, ety, dty), data in data_dict.items():
(sparse_fmt, arrays), urange, vrange = utils.graphdata2tensors(
data, idtype, bipartite=True
data,
idtype,
bipartite=True,
infer_node_count=need_infer or node_count_check,
)
node_tensor_dict[(sty, ety, dty)] = (sparse_fmt, arrays)
if need_infer:
num_src_nodes[sty] = max(num_src_nodes[sty], urange)
num_dst_nodes[dty] = max(num_dst_nodes[dty], vrange)
else: # sanity check
elif node_count_check: # sanity check
if num_src_nodes[sty] < urange:
raise DGLError(
"The given number of nodes of source node type {} must be larger"
Expand Down
1 change: 1 addition & 0 deletions python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def cast_to_minimum_dtype(v: CSCFormatBase):
sampled_csc,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
node_count_check=False,
)
)

Expand Down
20 changes: 15 additions & 5 deletions python/dgl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
SparseAdjTuple = namedtuple("SparseAdjTuple", ["format", "arrays"])


def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
def graphdata2tensors(
data, idtype=None, bipartite=False, infer_node_count=True, **kwargs
):
"""Function to convert various types of data to edge tensors and infer
the number of nodes.
Expand All @@ -137,6 +139,9 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
bipartite : bool, optional
Whether infer number of nodes of a bipartite graph --
num_src and num_dst can be different.
infer_node_count : bool, optional
Whether infer number of nodes at all. If False, num_src and num_dst
are returned as None.
kwargs
- edge_id_attr_name : The name (str) of the edge attribute that stores the edge
Expand Down Expand Up @@ -186,23 +191,28 @@ def graphdata2tensors(data, idtype=None, bipartite=False, **kwargs):
data.format, tuple(F.tensor(a) for a in data.arrays)
)

num_src, num_dst = None, None
if isinstance(data, SparseAdjTuple):
if idtype is not None:
data = SparseAdjTuple(
data.format, tuple(F.astype(a, idtype) for a in data.arrays)
)
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, list):
src, dst = elist2tensor(data, idtype)
data = SparseAdjTuple("coo", (src, dst))
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
elif isinstance(data, sp.sparse.spmatrix):
# We can get scipy matrix's number of rows and columns easily.
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
data = scipy2tensor(data, idtype)
elif isinstance(data, nx.Graph):
# We can get networkx graph's number of sources and destinations easily.
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
if infer_node_count:
num_src, num_dst = infer_num_nodes(data, bipartite=bipartite)
edge_id_attr_name = kwargs.get("edge_id_attr_name", None)
if bipartite:
top_map = kwargs.get("top_map")
Expand Down

0 comments on commit 7815fe8

Please sign in to comment.