Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 13, 2024
1 parent 7730aea commit 87ea35c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/cooperative_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

__all__ = ["CooperativeConvFunction", "CooperativeConv"]


class CooperativeConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):
Expand Down
8 changes: 6 additions & 2 deletions tests/python/pytorch/graphbolt/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,12 @@ def test_gpu_sampling_DataLoader(
x = gb.CooperativeConvFunction.apply(subgraph, x)
x, edge_index, size = subgraph.to_pyg(x)
x = x[0]
one = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device)
coo = torch.sparse_coo_tensor(edge_index.flipud(), one, size=(size[1], size[0]))
one = torch.ones(
edge_index.shape[1], dtype=x.dtype, device=x.device
)
coo = torch.sparse_coo_tensor(
edge_index.flipud(), one, size=(size[1], size[0])
)
x = torch.sparse.mm(coo, x)
assert x.shape[0] == minibatch.seeds.shape[0]
assert x.shape[1] == 1
Expand Down

0 comments on commit 87ea35c

Please sign in to comment.