Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 13, 2024
1 parent e4b7925 commit 02f49d2
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions python/dgl/graphbolt/impl/cooperative_conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Graphbolt cooperative convolution."""
from typing import Dict, Union

import torch

from ..sampled_subgraph import SampledSubgraph
Expand All @@ -23,7 +25,11 @@ class CooperativeConvFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, subgraph: SampledSubgraph, tensor: torch.Tensor):
def forward(
ctx,
subgraph: SampledSubgraph,
tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
):
"""Implements the forward pass."""
counts_sent = convert_to_hetero(subgraph._counts_sent)
counts_received = convert_to_hetero(subgraph._counts_received)
Expand All @@ -33,21 +39,24 @@ def forward(ctx, subgraph: SampledSubgraph, tensor: torch.Tensor):
counts_sent, counts_received, seed_inverse_ids, seed_sizes
)
outs = {}
for ntype, tensor in convert_to_hetero(tensor).items():
out = tensor.new_empty(
(sum(counts_sent[ntype]),) + tensor.shape[1:]
for ntype, typed_tensor in convert_to_hetero(tensor).items():
out = typed_tensor.new_empty(
(sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
)
all_to_all(
torch.split(out, counts_sent[ntype]),
torch.split(
tensor[seed_inverse_ids[ntype]], counts_received[ntype]
typed_tensor[seed_inverse_ids[ntype]],
counts_received[ntype],
),
)
outs[ntype] = out
return revert_to_homo(out)

@staticmethod
def backward(ctx, grad_output):
def backward(
ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]
):
"""Implements the forward pass."""
(
counts_sent,
Expand All @@ -56,16 +65,18 @@ def backward(ctx, grad_output):
seed_sizes,
) = ctx.saved_tensors
outs = {}
for ntype, grad_output in convert_to_hetero(grad_output).items():
out = grad_output.new_empty(
(sum(counts_received[ntype]),) + grad_output.shape[1:]
for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
out = typed_grad_output.new_empty(
(sum(counts_received[ntype]),) + typed_grad_output.shape[1:]
)
all_to_all(
torch.split(out, counts_received[ntype]),
torch.split(grad_output, counts_sent[ntype]),
torch.split(typed_grad_output, counts_sent[ntype]),
)
i = out.new_empty(2, out.shape[0], dtype=torch.int64)
i[0] = torch.arange(out.shape[0], device=grad_output.device) # src
i[0] = torch.arange(
out.shape[0], device=typed_grad_output.device
) # src
i[1] = seed_inverse_ids[ntype] # dst
coo = torch.sparse_coo_tensor(
i, 1, size=(seed_sizes[ntype], i.shape[1])
Expand All @@ -89,6 +100,10 @@ class CooperativeConv(torch.nn.Module):
GPUs at the expense of communication.
"""

def forward(self, subgraph: SampledSubgraph, x: torch.Tensor):
def forward(
self,
subgraph: SampledSubgraph,
x: Union[torch.Tensor, Dict[str, torch.Tensor]],
):
"""Implements the forward pass."""
return CooperativeConvFunction.apply(subgraph, x)

0 comments on commit 02f49d2

Please sign in to comment.