diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index 3d70786c6e7f..5f7060568bfc 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -23,7 +23,7 @@ class CooperativeConvFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor): + def forward(ctx, subgraph: SampledSubgraph, input: torch.Tensor): """Implements the forward pass.""" counts_sent = subgraph._counts_sent counts_received = subgraph._counts_received @@ -32,10 +32,10 @@ def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor): ctx.save_for_backward( counts_sent, counts_received, seed_inverse_ids, seed_sizes ) - out = h.new_empty((sum(counts_sent),) + h.shape[1:]) + out = input.new_empty((sum(counts_sent),) + input.shape[1:]) all_to_all( torch.split(out, counts_sent), - torch.split(h[seed_inverse_ids], counts_received), + torch.split(input[seed_inverse_ids], counts_received), ) return out