From 6a02b52113b8736b75f99cef47c567f060268be6 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Thu, 12 Sep 2024 23:51:00 +0000 Subject: [PATCH 01/11] [GraphBolt][CUDA] Add `CooperativeConv`. --- python/dgl/graphbolt/impl/cooperative_conv.py | 35 +++++++++++++++++++ python/dgl/graphbolt/impl/neighbor_sampler.py | 30 ++++++++++++++-- python/dgl/graphbolt/subgraph_sampler.py | 5 +++ 3 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 python/dgl/graphbolt/impl/cooperative_conv.py diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py new file mode 100644 index 000000000000..d977cffc15de --- /dev/null +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -0,0 +1,35 @@ +from ..sampled_subgraph import SampledSubgraph +from ..subgraph_sampler import all_to_all + +import torch + +class CooperativeConvFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, subgraph: SampledSubgraph, h): + counts_sent = subgraph._counts_sent + counts_received = subgraph._counts_received + seed_inverse_ids = subgraph._seed_inverse_ids + seed_sizes = subgraph._seed_sizes + ctx.save_for_backward(counts_sent, counts_received, seed_inverse_ids, seed_sizes) + out = h.new_empty((sum(counts_sent),) + h.shape[1:]) + all_to_all(torch.split(out, counts_sent), torch.split(h[seed_inverse_ids], counts_received)) + return out + + @staticmethod + def backward(ctx, grad_output): + counts_sent, counts_received, seed_inverse_ids, seed_sizes = ctx.saved_tensors + out = grad_output.new_empty((sum(counts_received),) + grad_output.shape[1:]) + all_to_all(torch.split(out, counts_received), torch.split(grad_output, counts_sent)) + i = out.new_empty(2, out.shape[0], dtype=torch.int64) + i[0] = torch.arange(out.shape[0], device=grad_output.device) # src + i[1] = seed_inverse_ids # dst + coo = torch.sparse_coo_tensor(i, 1, size=(seed_sizes, i.shape[1])) + rout = torch.sparse.mm(coo, out) + return None, rout + +class CooperativeConv(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, subgraph: SampledSubgraph, x): + return CooperativeConvFunction.apply(subgraph, x) diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 352dedc067f2..0e1945430d75 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -601,17 +601,16 @@ def _seeds_cooperative_exchange_2(minibatch): typed_seeds.split(typed_counts_sent), ) seeds_received[ntype] = typed_seeds_received - subgraph._seeds_received = seeds_received + minibatch._seed_nodes = seeds_received subgraph._counts_sent = revert_to_homo(counts_sent) subgraph._counts_received = revert_to_homo(counts_received) return minibatch @staticmethod def _seeds_cooperative_exchange_3(minibatch): - subgraph = minibatch.sampled_subgraphs[0] nodes = { ntype: [typed_seeds] - for ntype, typed_seeds in subgraph._seeds_received.items() + for ntype, typed_seeds in minibatch._seed_nodes.items() } minibatch._unique_future = unique_and_compact( nodes, 0, 1, async_op=True @@ -627,6 +626,11 @@ def _seeds_cooperative_exchange_4(minibatch): } minibatch._seed_nodes = revert_to_homo(unique_seeds) subgraph = minibatch.sampled_subgraphs[0] + sizes = { + ntype: typed_seeds.size(0) + for ntype, typed_seeds in unique_seeds.items() + } + subgraph._seed_sizes = revert_to_homo(sizes) subgraph._seed_inverse_ids = revert_to_homo(inverse_seeds) return minibatch @@ -831,6 +835,16 @@ class 
NeighborSampler(NeighborSamplerImpl):
     gpu_cache_threshold : int, optional
         Determines how many times a vertex needs to be accessed before its
         neighborhood ends up being cached on the GPU.
+    cooperative: bool, optional
+        Boolean indicating whether Cooperative Minibatching will be used,
+        which was initially proposed in
+        `Deep Graph Library PR#4337`__
+        and was later fully described in
+        `Cooperative Minibatching in Graph Neural Networks
+        `__. Cooperation between the GPUs
+        eliminates duplicate work performed across the GPUs due to the
+        overlapping sampled k-hop neighborhoods of seed nodes when performing
+        GNN minibatching.
     asynchronous: bool
         Boolean indicating whether sampling and compaction stages should run
         in background threads to hide the latency of CPU GPU synchronization.
@@ -986,6 +1000,16 @@ class LayerNeighborSampler(NeighborSamplerImpl):
     gpu_cache_threshold : int, optional
         Determines how many times a vertex needs to be accessed before its
         neighborhood ends up being cached on the GPU.
+    cooperative: bool, optional
+        Boolean indicating whether Cooperative Minibatching will be used,
+        which was initially proposed in
+        `Deep Graph Library PR#4337`__
+        and was later fully described in
+        `Cooperative Minibatching in Graph Neural Networks
+        `__. Cooperation between the GPUs
+        eliminates duplicate work performed across the GPUs due to the
+        overlapping sampled k-hop neighborhoods of seed nodes when performing
+        GNN minibatching.
     asynchronous: bool
         Boolean indicating whether sampling and compaction stages should run
         in background threads to hide the latency of CPU GPU synchronization.
diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py
index 556950982fb7..f4683a5b05dc 100644
--- a/python/dgl/graphbolt/subgraph_sampler.py
+++ b/python/dgl/graphbolt/subgraph_sampler.py
@@ -275,6 +275,11 @@ def _seeds_cooperative_exchange_4(minibatch):
             ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
         }
         minibatch._seed_nodes = revert_to_homo(unique_seeds)
+        sizes = {
+            ntype: typed_seeds.size(0)
+            for ntype, typed_seeds in unique_seeds.items()
+        }
+        minibatch._seed_sizes = revert_to_homo(sizes)
         minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
         return minibatch

From 486dfc286aa14699ca77ef3e741d4e7257eb21a6 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Thu, 12 Sep 2024 23:59:40 +0000
Subject: [PATCH 02/11] add documentation.
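
Document the cooperative all-to-all convolution. For illustration, a
minimal training-loop sketch of how it is meant to be used (the GNN layer
objects and the feature name are placeholders, and a torch.distributed
process group spanning the GPUs is assumed to have been initialized):

    import dgl.graphbolt as gb

    for minibatch in dataloader:  # dataloader built with cooperative=True
        h = minibatch.node_features["feat"]
        for layer, subgraph in zip(gnn_layers, minibatch.sampled_subgraphs):
            # Redistribute features across the ranks so that every rank
            # computes the outputs of its own seed nodes exactly once.
            h = gb.CooperativeConvFunction.apply(subgraph, h)
            h = layer(subgraph, h)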
---
 python/dgl/graphbolt/impl/cooperative_conv.py | 50 +++++++++++++++----
 1 file changed, 40 insertions(+), 10 deletions(-)

diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index d977cffc15de..e98a69833ee0 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -1,35 +1,65 @@
+import torch
+
 from ..sampled_subgraph import SampledSubgraph
 from ..subgraph_sampler import all_to_all
 
-import torch
 
 class CooperativeConvFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, subgraph: SampledSubgraph, h):
+    def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):
         counts_sent = subgraph._counts_sent
         counts_received = subgraph._counts_received
         seed_inverse_ids = subgraph._seed_inverse_ids
         seed_sizes = subgraph._seed_sizes
-        ctx.save_for_backward(counts_sent, counts_received, seed_inverse_ids, seed_sizes)
+        ctx.save_for_backward(
+            counts_sent, counts_received, seed_inverse_ids, seed_sizes
+        )
         out = h.new_empty((sum(counts_sent),) + h.shape[1:])
-        all_to_all(torch.split(out, counts_sent), torch.split(h[seed_inverse_ids], counts_received))
+        all_to_all(
+            torch.split(out, counts_sent),
+            torch.split(h[seed_inverse_ids], counts_received),
+        )
         return out
 
     @staticmethod
    def backward(ctx, grad_output):
-        counts_sent, counts_received, seed_inverse_ids, seed_sizes = ctx.saved_tensors
-        out = grad_output.new_empty((sum(counts_received),) + grad_output.shape[1:])
-        all_to_all(torch.split(out, counts_received), torch.split(grad_output, counts_sent))
+        (
+            counts_sent,
+            counts_received,
+            seed_inverse_ids,
+            seed_sizes,
+        ) = ctx.saved_tensors
+        out = grad_output.new_empty(
+            (sum(counts_received),) + grad_output.shape[1:]
+        )
+        all_to_all(
+            torch.split(out, counts_received),
+            torch.split(grad_output, counts_sent),
+        )
         i = out.new_empty(2, out.shape[0], dtype=torch.int64)
         i[0] = torch.arange(out.shape[0], device=grad_output.device)  # src
-        i[1] = seed_inverse_ids # dst
+        i[1] = seed_inverse_ids  # dst
         coo = torch.sparse_coo_tensor(i, 1, size=(seed_sizes, i.shape[1]))
         rout = torch.sparse.mm(coo, out)
         return None, rout
 
+
 class CooperativeConv(torch.nn.Module):
+    """Cooperative convolution operation from Cooperative Minibatching.
+
+    Implements the `all-to-all` message passing algorithm
+    in Cooperative Minibatching, which was initially proposed in
+    `Deep Graph Library PR#4337`__ and
+    was later fully described in
+    `Cooperative Minibatching in Graph Neural Networks
+    `__.
+    Cooperation between the GPUs eliminates duplicate work performed across the
+    GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when
+    performing GNN minibatching. This reduces the redundant computations across
+    GPUs at the expense of communication.
+ """ def __init__(self): super().__init__() - - def forward(self, subgraph: SampledSubgraph, x): + + def forward(self, subgraph: SampledSubgraph, x: torch.Tensor): return CooperativeConvFunction.apply(subgraph, x) From c933aec4cadfa6136603e29dc9ec9e642931d352 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 00:00:55 +0000 Subject: [PATCH 03/11] linting --- python/dgl/graphbolt/impl/cooperative_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index e98a69833ee0..543c011aa295 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -46,7 +46,7 @@ def backward(ctx, grad_output): class CooperativeConv(torch.nn.Module): """Cooperative convolution operation from Cooperative Minibatching. - + Implements the `all-to-all` message passing algorithm in Cooperative Minibatching, which was initially proposed in `Deep Graph Library PR#4337`__ and @@ -58,6 +58,7 @@ class CooperativeConv(torch.nn.Module): performing GNN minibatching. This reduces the redundant computations across GPUs at the expense of communication. """ + def __init__(self): super().__init__() From 7730aea5bcc036e67dbac2d9a49614ba358c15da Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 03:31:50 +0000 Subject: [PATCH 04/11] fix the test. --- python/dgl/graphbolt/impl/__init__.py | 1 + python/dgl/graphbolt/impl/cooperative_conv.py | 1 + python/dgl/graphbolt/impl/neighbor_sampler.py | 2 ++ python/dgl/graphbolt/subgraph_sampler.py | 2 ++ tests/python/pytorch/graphbolt/test_dataloader.py | 15 +++++++++++++++ 5 files changed, 21 insertions(+) diff --git a/python/dgl/graphbolt/impl/__init__.py b/python/dgl/graphbolt/impl/__init__.py index 19fef44e462c..f4e53327c3ae 100644 --- a/python/dgl/graphbolt/impl/__init__.py +++ b/python/dgl/graphbolt/impl/__init__.py @@ -15,3 +15,4 @@ from .gpu_graph_cache import * from .cpu_feature_cache import * from .cpu_cached_feature import * +from .cooperative_conv import * diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index 543c011aa295..fd75a59adf46 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -3,6 +3,7 @@ from ..sampled_subgraph import SampledSubgraph from ..subgraph_sampler import all_to_all +__all__ = ["CooperativeConvFunction", "CooperativeConv"] class CooperativeConvFunction(torch.autograd.Function): @staticmethod diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 0e1945430d75..7ddba6d7ccac 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -601,6 +601,8 @@ def _seeds_cooperative_exchange_2(minibatch): typed_seeds.split(typed_counts_sent), ) seeds_received[ntype] = typed_seeds_received + counts_sent[ntype] = typed_counts_sent + counts_received[ntype] = typed_counts_received minibatch._seed_nodes = seeds_received subgraph._counts_sent = revert_to_homo(counts_sent) subgraph._counts_received = revert_to_homo(counts_received) diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index f4683a5b05dc..01ba666465ad 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -251,6 +251,8 @@ def _seeds_cooperative_exchange_2(minibatch, group=None): 
group, ) seeds_received[ntype] = typed_seeds_received + counts_sent[ntype] = typed_counts_sent + counts_received[ntype] = typed_counts_received minibatch._seed_nodes = seeds_received minibatch._counts_sent = revert_to_homo(counts_sent) minibatch._counts_received = revert_to_homo(counts_received) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index b02c820dd60d..aad5ea3f4812 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -6,6 +6,7 @@ import dgl import dgl.graphbolt +import dgl.graphbolt as gb import pytest import torch import torch.distributed as thd @@ -194,5 +195,19 @@ def test_gpu_sampling_DataLoader( if sampler_name == "LayerNeighborSampler": assert torch.equal(edge_feature, edge_feature_ref) assert len(list(dataloader)) == N // B + + if asynchronous and cooperative: + for minibatch in minibatches: + x = torch.ones((minibatch.node_ids().size(0), 1), device=F.ctx()) + for subgraph in minibatch.sampled_subgraphs: + x = gb.CooperativeConvFunction.apply(subgraph, x) + x, edge_index, size = subgraph.to_pyg(x) + x = x[0] + one = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device) + coo = torch.sparse_coo_tensor(edge_index.flipud(), one, size=(size[1], size[0])) + x = torch.sparse.mm(coo, x) + assert x.shape[0] == minibatch.seeds.shape[0] + assert x.shape[1] == 1 + if thd.is_initialized(): thd.destroy_process_group() From 87ea35ceceb752bda56b38544c134030d16d3015 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 03:33:36 +0000 Subject: [PATCH 05/11] linting --- python/dgl/graphbolt/impl/cooperative_conv.py | 1 + tests/python/pytorch/graphbolt/test_dataloader.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index fd75a59adf46..003cdce78a75 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -5,6 +5,7 @@ __all__ = ["CooperativeConvFunction", "CooperativeConv"] + class CooperativeConvFunction(torch.autograd.Function): @staticmethod def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor): diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index aad5ea3f4812..5d5d44fd1eb7 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -203,8 +203,12 @@ def test_gpu_sampling_DataLoader( x = gb.CooperativeConvFunction.apply(subgraph, x) x, edge_index, size = subgraph.to_pyg(x) x = x[0] - one = torch.ones(edge_index.shape[1], dtype=x.dtype, device=x.device) - coo = torch.sparse_coo_tensor(edge_index.flipud(), one, size=(size[1], size[0])) + one = torch.ones( + edge_index.shape[1], dtype=x.dtype, device=x.device + ) + coo = torch.sparse_coo_tensor( + edge_index.flipud(), one, size=(size[1], size[0]) + ) x = torch.sparse.mm(coo, x) assert x.shape[0] == minibatch.seeds.shape[0] assert x.shape[1] == 1 From e087be986f7560a8b689edcf800f407284f23c80 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 03:37:16 +0000 Subject: [PATCH 06/11] linting --- python/dgl/graphbolt/impl/cooperative_conv.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index 003cdce78a75..91300e6a7565 100644 --- 
a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -1,3 +1,4 @@
+"""Graphbolt cooperative convolution."""
 import torch
 
 from ..sampled_subgraph import SampledSubgraph
@@ -7,6 +8,20 @@
 __all__ = ["CooperativeConvFunction", "CooperativeConv"]
 
 
 class CooperativeConvFunction(torch.autograd.Function):
+    """Cooperative convolution operation from Cooperative Minibatching.
+
+    Implements the `all-to-all` message passing algorithm
+    in Cooperative Minibatching, which was initially proposed in
+    `Deep Graph Library PR#4337`__ and
+    was later fully described in
+    `Cooperative Minibatching in Graph Neural Networks
+    `__.
+    Cooperation between the GPUs eliminates duplicate work performed across the
+    GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when
+    performing GNN minibatching. This reduces the redundant computations across
+    GPUs at the expense of communication.
+    """
+
     @staticmethod
     def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):

From 81a631b110340e8501008bdf87262b3ff9766d48 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 13 Sep 2024 03:49:12 +0000
Subject: [PATCH 07/11] linting

---
 python/dgl/graphbolt/impl/cooperative_conv.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 91300e6a7565..3d70786c6e7f 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -24,6 +24,7 @@ class CooperativeConvFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):
+        """Implements the forward pass."""
         counts_sent = subgraph._counts_sent
         counts_received = subgraph._counts_received
         seed_inverse_ids = subgraph._seed_inverse_ids
@@ -40,6 +41,7 @@ def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor):
 
     @staticmethod
     def backward(ctx, grad_output):
+        """Implements the backward pass."""
         (
             counts_sent,
             counts_received,
@@ -76,8 +78,6 @@ class CooperativeConv(torch.nn.Module):
     GPUs at the expense of communication.
""" - def __init__(self): - super().__init__() - def forward(self, subgraph: SampledSubgraph, x: torch.Tensor): + """Implements the forward pass.""" return CooperativeConvFunction.apply(subgraph, x) From 13a288179d4a76b62b784daddc3214463a7bc076 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 03:52:27 +0000 Subject: [PATCH 08/11] linting --- python/dgl/graphbolt/impl/cooperative_conv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index 3d70786c6e7f..5f7060568bfc 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -23,7 +23,7 @@ class CooperativeConvFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor): + def forward(ctx, subgraph: SampledSubgraph, input: torch.Tensor): """Implements the forward pass.""" counts_sent = subgraph._counts_sent counts_received = subgraph._counts_received @@ -32,10 +32,10 @@ def forward(ctx, subgraph: SampledSubgraph, h: torch.Tensor): ctx.save_for_backward( counts_sent, counts_received, seed_inverse_ids, seed_sizes ) - out = h.new_empty((sum(counts_sent),) + h.shape[1:]) + out = input.new_empty((sum(counts_sent),) + input.shape[1:]) all_to_all( torch.split(out, counts_sent), - torch.split(h[seed_inverse_ids], counts_received), + torch.split(input[seed_inverse_ids], counts_received), ) return out From d20efc18f6116908425443a6bfca9067bd064edc Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 03:56:03 +0000 Subject: [PATCH 09/11] linting --- python/dgl/graphbolt/impl/cooperative_conv.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index 5f7060568bfc..37eea301dc1b 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -23,7 +23,7 @@ class CooperativeConvFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, subgraph: SampledSubgraph, input: torch.Tensor): + def forward(ctx, subgraph: SampledSubgraph, tensor: torch.Tensor): """Implements the forward pass.""" counts_sent = subgraph._counts_sent counts_received = subgraph._counts_received @@ -32,10 +32,10 @@ def forward(ctx, subgraph: SampledSubgraph, input: torch.Tensor): ctx.save_for_backward( counts_sent, counts_received, seed_inverse_ids, seed_sizes ) - out = input.new_empty((sum(counts_sent),) + input.shape[1:]) + out = tensor.new_empty((sum(counts_sent),) + tensor.shape[1:]) all_to_all( torch.split(out, counts_sent), - torch.split(input[seed_inverse_ids], counts_received), + torch.split(tensor[seed_inverse_ids], counts_received), ) return out From e4b79250a89d76b4b8dbe32b68a6d74ce3e527a4 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 13 Sep 2024 04:08:54 +0000 Subject: [PATCH 10/11] make the conv hetero. 
---
 python/dgl/graphbolt/impl/cooperative_conv.py | 59 +++++++++++--------
 python/dgl/graphbolt/subgraph_sampler.py      |  8 +++
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 37eea301dc1b..5db23eef68e1 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -2,7 +2,7 @@
 import torch
 
 from ..sampled_subgraph import SampledSubgraph
-from ..subgraph_sampler import all_to_all
+from ..subgraph_sampler import all_to_all, convert_to_hetero, revert_to_homo
 
 __all__ = ["CooperativeConvFunction", "CooperativeConv"]
 
@@ -25,20 +25,32 @@ class CooperativeConvFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, subgraph: SampledSubgraph, tensor: torch.Tensor):
         """Implements the forward pass."""
-        counts_sent = subgraph._counts_sent
-        counts_received = subgraph._counts_received
-        seed_inverse_ids = subgraph._seed_inverse_ids
-        seed_sizes = subgraph._seed_sizes
-        ctx.save_for_backward(
-            counts_sent, counts_received, seed_inverse_ids, seed_sizes
-        )
-        out = tensor.new_empty((sum(counts_sent),) + tensor.shape[1:])
-        all_to_all(
-            torch.split(out, counts_sent),
-            torch.split(tensor[seed_inverse_ids], counts_received),
-        )
-        return out
+        counts_sent = convert_to_hetero(subgraph._counts_sent)
+        counts_received = convert_to_hetero(subgraph._counts_received)
+        seed_inverse_ids = convert_to_hetero(subgraph._seed_inverse_ids)
+        seed_sizes = convert_to_hetero(subgraph._seed_sizes)
+        # The metadata now consists of dictionaries and Python ints, which
+        # save_for_backward() rejects, so stash it on ctx directly.
+        ctx.communication_variables = (
+            counts_sent,
+            counts_received,
+            seed_inverse_ids,
+            seed_sizes,
+        )
+        outs = {}
+        for ntype, tensor in convert_to_hetero(tensor).items():
+            out = tensor.new_empty(
+                (sum(counts_sent[ntype]),) + tensor.shape[1:]
+            )
+            all_to_all(
+                torch.split(out, counts_sent[ntype]),
+                torch.split(
+                    tensor[seed_inverse_ids[ntype]], counts_received[ntype]
+                ),
+            )
+            outs[ntype] = out
+        return revert_to_homo(outs)
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -48,19 +60,27 @@ def backward(ctx, grad_output):
             seed_inverse_ids,
             seed_sizes,
-        ) = ctx.saved_tensors
-        out = grad_output.new_empty(
-            (sum(counts_received),) + grad_output.shape[1:]
-        )
-        all_to_all(
-            torch.split(out, counts_received),
-            torch.split(grad_output, counts_sent),
-        )
-        i = out.new_empty(2, out.shape[0], dtype=torch.int64)
-        i[0] = torch.arange(out.shape[0], device=grad_output.device)  # src
-        i[1] = seed_inverse_ids  # dst
-        coo = torch.sparse_coo_tensor(i, 1, size=(seed_sizes, i.shape[1]))
-        rout = torch.sparse.mm(coo, out)
-        return None, rout
+        ) = ctx.communication_variables
+        outs = {}
+        for ntype, grad_output in convert_to_hetero(grad_output).items():
+            out = grad_output.new_empty(
+                (sum(counts_received[ntype]),) + grad_output.shape[1:]
+            )
+            all_to_all(
+                torch.split(out, counts_received[ntype]),
+                torch.split(grad_output, counts_sent[ntype]),
+            )
+            # Scatter-add gradients of duplicate seeds onto unique seeds:
+            # rout[k] = sum of out[j] with seed_inverse_ids[ntype][j] == k.
+            i = out.new_empty(2, out.shape[0], dtype=torch.int64)
+            i[0] = seed_inverse_ids[ntype]  # dst
+            i[1] = torch.arange(out.shape[0], device=grad_output.device)  # src
+            coo = torch.sparse_coo_tensor(
+                i,
+                torch.ones(i.shape[1], dtype=out.dtype, device=out.device),
+                size=(seed_sizes[ntype], i.shape[1]),
+            )
+            outs[ntype] = torch.sparse.mm(coo, out)
+        return None, revert_to_homo(outs)
 
 
 class CooperativeConv(torch.nn.Module):
diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py
index 01ba666465ad..dd5093ae5f69 100644
--- a/python/dgl/graphbolt/subgraph_sampler.py
+++ b/python/dgl/graphbolt/subgraph_sampler.py
@@ -16,6 +16,7 @@
 __all__ = [
     "SubgraphSampler",
     "all_to_all",
+    "convert_to_hetero",
     "revert_to_homo",
 ]
 
@@ -89,6 +90,13 @@ def revert_to_homo(d: dict):
     return list(d.values())[0] if is_homogenous else d
 
 
+def convert_to_hetero(item):
+    """Utility function to convert homogeneous data to heterogeneous with a
+    single node type."""
+    is_heterogenous = isinstance(item, dict)
+    return item if is_heterogenous else {"_N": item}
+
+
 @functional_datapipe("sample_subgraph")
 class SubgraphSampler(MiniBatchTransformer):
     """A subgraph sampler used to sample a subgraph from a given set of nodes

From 02f49d2c22073ba15757809373767d9ac780a594 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Fri, 13 Sep 2024 04:14:47 +0000
Subject: [PATCH 11/11] linting

---
 python/dgl/graphbolt/impl/cooperative_conv.py | 39 +++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 5db23eef68e1..28f11bc8b317 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -1,4 +1,6 @@
 """Graphbolt cooperative convolution."""
+from typing import Dict, Union
+
 import torch
 
 from ..sampled_subgraph import SampledSubgraph
@@ -23,7 +25,11 @@ class CooperativeConvFunction(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, subgraph: SampledSubgraph, tensor: torch.Tensor):
+    def forward(
+        ctx,
+        subgraph: SampledSubgraph,
+        tensor: Union[torch.Tensor, Dict[str, torch.Tensor]],
+    ):
         """Implements the forward pass."""
         counts_sent = convert_to_hetero(subgraph._counts_sent)
@@ -45,21 +51,24 @@ def forward(ctx, subgraph: SampledSubgraph, tensor: torch.Tensor):
             seed_sizes,
         )
         outs = {}
-        for ntype, tensor in convert_to_hetero(tensor).items():
-            out = tensor.new_empty(
-                (sum(counts_sent[ntype]),) + tensor.shape[1:]
+        for ntype, typed_tensor in convert_to_hetero(tensor).items():
+            out = typed_tensor.new_empty(
+                (sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
             )
             all_to_all(
                 torch.split(out, counts_sent[ntype]),
                 torch.split(
-                    tensor[seed_inverse_ids[ntype]], counts_received[ntype]
+                    typed_tensor[seed_inverse_ids[ntype]],
+                    counts_received[ntype],
                 ),
             )
             outs[ntype] = out
         return revert_to_homo(outs)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]
+    ):
         """Implements the backward pass."""
         (
             counts_sent,
@@ -72,24 +81,28 @@ def backward(ctx, grad_output):
             seed_inverse_ids,
             seed_sizes,
         ) = ctx.communication_variables
         outs = {}
-        for ntype, grad_output in convert_to_hetero(grad_output).items():
-            out = grad_output.new_empty(
-                (sum(counts_received[ntype]),) + grad_output.shape[1:]
+        for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
+            out = typed_grad_output.new_empty(
+                (sum(counts_received[ntype]),) + typed_grad_output.shape[1:]
             )
             all_to_all(
                 torch.split(out, counts_received[ntype]),
-                torch.split(grad_output, counts_sent[ntype]),
+                torch.split(typed_grad_output, counts_sent[ntype]),
             )
             # Scatter-add gradients of duplicate seeds onto unique seeds:
             # rout[k] = sum of out[j] with seed_inverse_ids[ntype][j] == k.
             i = out.new_empty(2, out.shape[0], dtype=torch.int64)
             i[0] = seed_inverse_ids[ntype]  # dst
-            i[1] = torch.arange(out.shape[0], device=grad_output.device)  # src
+            i[1] = torch.arange(
+                out.shape[0], device=typed_grad_output.device
+            )  # src
             coo = torch.sparse_coo_tensor(
                 i,
                 torch.ones(i.shape[1], dtype=out.dtype, device=out.device),
                 size=(seed_sizes[ntype], i.shape[1]),
             )
             outs[ntype] = torch.sparse.mm(coo, out)
         return None, revert_to_homo(outs)
 
 
 class CooperativeConv(torch.nn.Module):
@@ -97,6 +110,10 @@ class CooperativeConv(torch.nn.Module):
     GPUs at the expense of communication.
""" - def forward(self, subgraph: SampledSubgraph, x: torch.Tensor): + def forward( + self, + subgraph: SampledSubgraph, + x: Union[torch.Tensor, Dict[str, torch.Tensor]], + ): """Implements the forward pass.""" return CooperativeConvFunction.apply(subgraph, x)