diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index fd75a59adf46..003cdce78a75 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -5,6 +5,7 @@ __all__ = ["CooperativeConvFunction", "CooperativeConv"] + class CooperativeConvFunction(torch.autograd.Function): @staticmethod def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor): diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index aad5ea3f4812..5d5d44fd1eb7 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -203,8 +203,12 @@ def test_gpu_sampling_DataLoader( x = gb.CooperativeConvFunction.apply(subgraph, x) x, edge_index, size = subgraph.to_pyg(x) x = x[0] - one = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device) - coo = torch.sparse_coo_tensor(edge_index.flipud(), one, size=(size[1], size[0])) + one = torch.ones( + edge_index.shape[1], dtype=x.dtype, device=x.device + ) + coo = torch.sparse_coo_tensor( + edge_index.flipud(), one, size=(size[1], size[0]) + ) x = torch.sparse.mm(coo, x) assert x.shape[0] == minibatch.seeds.shape[0] assert x.shape[1] == 1