Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 13, 2024
1 parent 81a631b commit 13a2881
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/dgl/graphbolt/impl/cooperative_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class CooperativeConvFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):
def forward(ctx, subgraph: SampledSubgraph, input: torch.Tensor):
"""Implements the forward pass."""
counts_sent = subgraph._counts_sent
counts_received = subgraph._counts_received
Expand All @@ -32,10 +32,10 @@ def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):
ctx.save_for_backward(
counts_sent, counts_received, seed_inverse_ids, seed_sizes
)
out = h.new_empty((sum(counts_sent),) + h.shape[1:])
out = input.new_empty((sum(counts_sent),) + input.shape[1:])
all_to_all(
torch.split(out, counts_sent),
torch.split(h[seed_inverse_ids], counts_received),
torch.split(input[seed_inverse_ids], counts_received),
)
return out

Expand Down

0 comments on commit 13a2881

Please sign in to comment.