diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py
index 75dab4d54cd2..cf9d5f4104c2 100644
--- a/python/dgl/graphbolt/feature_fetcher.py
+++ b/python/dgl/graphbolt/feature_fetcher.py
@@ -8,6 +8,7 @@
 from torch.utils.data import functional_datapipe
 
 from .base import etype_tuple_to_str
+from .impl.cooperative_conv import CooperativeConvFunction
 from .minibatch_transformer import MiniBatchTransformer
 
 
@@ -73,6 +74,16 @@ class FeatureFetcher(MiniBatchTransformer):
         If True, the feature fetcher will overlap the UVA feature fetcher
         operations with the rest of operations by using an alternative CUDA
         stream or utilizing asynchronous operations. Default is True.
+    cooperative: bool, optional
+        Boolean indicating whether Cooperative Minibatching will be used. The
+        approach was initially proposed in
+        `Deep Graph Library PR#4337`__
+        and was later fully described in
+        `Cooperative Minibatching in Graph Neural Networks
+        `__. Cooperation between the GPUs eliminates the duplicate work that
+        would otherwise be performed across GPUs due to the overlapping
+        sampled k-hop neighborhoods of seed nodes when performing GNN
+        minibatching.
     """
 
     def __init__(
@@ -82,6 +93,7 @@ def __init__(
         node_feature_keys=None,
         edge_feature_keys=None,
         overlap_fetch=True,
+        cooperative=False,
     ):
         datapipe = datapipe.mark_feature_fetcher_start()
         self.feature_store = feature_store
@@ -113,9 +125,12 @@ def __init__(
                 datapipe = datapipe.transform(
                     partial(self._execute_stage, i)
                 ).buffer(1)
-        super().__init__(
-            datapipe, self._identity if max_val == 0 else self._final_stage
-        )
+        if max_val > 0:
+            datapipe = datapipe.transform(self._final_stage)
+        if cooperative:
+            datapipe = datapipe.transform(self._cooperative_exchange)
+            datapipe = datapipe.buffer()
+        super().__init__(datapipe)
         # A positive value indicates that the overlap optimization is enabled.
         self.max_num_stages = max_val
 
@@ -145,6 +160,26 @@ def _final_stage(data):
                 features[key] = value.wait()
         return data
 
+    def _cooperative_exchange(self, data):
+        subgraph = data.sampled_subgraphs[0]
+        is_heterogeneous = isinstance(
+            self.node_feature_keys, Dict
+        ) or isinstance(self.edge_feature_keys, Dict)
+        if is_heterogeneous:
+            node_features = {key: {} for key, _ in data.node_features.keys()}
+            for (key, ntype), feature in data.node_features.items():
+                node_features[key][ntype] = feature
+            for key, feature in node_features.items():
+                new_feature = CooperativeConvFunction.apply(subgraph, feature)
+                for ntype, tensor in new_feature.items():
+                    data.node_features[(key, ntype)] = tensor
+        else:
+            for key in data.node_features:
+                feature = data.node_features[key]
+                new_feature = CooperativeConvFunction.apply(subgraph, feature)
+                data.node_features[key] = new_feature
+        return data
+
     def _read(self, data):
         """
         Fill in the node/edge features field in data.
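
Usage note (not part of the patch): the sketch below shows how the new `cooperative` flag would plug into a GraphBolt datapipe on one rank of a multi-GPU job. The names `graph`, `features`, `train_set`, and `device` are assumed pre-built inputs, and passing `cooperative=True` to the neighbor sampler is an assumption made here so that the sampled subgraphs carry the exchange metadata the feature fetcher needs.

    import dgl.graphbolt as gb

    # Assumed pre-built inputs: a GPU-resident sampling graph, a feature
    # store, a training item set, and this rank's CUDA device.
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.copy_to(device)
    # Assumption: Cooperative Minibatching is also enabled on the sampler so
    # that sampled_subgraphs carry the all-to-all exchange metadata.
    datapipe = datapipe.sample_neighbor(graph, [10, 10], cooperative=True)
    datapipe = datapipe.fetch_feature(
        features,
        node_feature_keys=["feat"],
        overlap_fetch=True,
        cooperative=True,  # the flag added by this patch
    )
    dataloader = gb.DataLoader(datapipe)

With `cooperative=True`, the fetcher appends the `_cooperative_exchange` transform plus an extra buffer after the final feature-fetch stage, so each rank fetches only its local features and then redistributes them via `CooperativeConvFunction`.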
diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index 28f11bc8b317..cb3d39d4d980 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -57,7 +57,7 @@ def forward(
     def backward(
         ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]
     ):
-        """Implements the forward pass."""
+        """Implements the backward pass."""
         (
             counts_sent,
             counts_received,
diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py
index 5d5d44fd1eb7..ee8f2b0cb9f5 100644
--- a/tests/python/pytorch/graphbolt/test_dataloader.py
+++ b/tests/python/pytorch/graphbolt/test_dataloader.py
@@ -146,6 +146,7 @@ def test_gpu_sampling_DataLoader(
             ["a", "b", "c"],
             ["d"],
             overlap_fetch=overlap_feature_fetch and i == 0,
+            cooperative=asynchronous and cooperative and i == 0,
         )
         dataloaders.append(dgl.graphbolt.DataLoader(datapipe))
     dataloader, dataloader2 = dataloaders
@@ -159,6 +160,8 @@ def test_gpu_sampling_DataLoader(
             bufferer_cnt += 2 * num_layers + 1  # _preprocess stage has 1.
             if cooperative:
                 bufferer_cnt += 3 * num_layers
+        if enable_feature_fetch:
+            bufferer_cnt += 1  # feature fetch has 1.
         if cooperative:
             # _preprocess stage and each sampling layer.
             bufferer_cnt += 3
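
For context (not part of the patch): `CooperativeConvFunction` is a torch autograd function, so the corrected docstring belongs to the pass that routes gradients back through the inter-GPU exchange. A minimal sketch of exercising both passes follows; it assumes a running multi-rank job where `subgraph` is a sampled subgraph produced on this rank with Cooperative Minibatching enabled, and `num_local_rows` is a hypothetical placeholder for the feature row count that subgraph expects.

    import torch
    from dgl.graphbolt.impl.cooperative_conv import CooperativeConvFunction

    # Hypothetical local feature tensor for this rank (shape is an assumption).
    feature = torch.randn(
        num_local_rows, 128, device="cuda", requires_grad=True
    )
    output = CooperativeConvFunction.apply(subgraph, feature)
    # Backpropagating through the all-to-all exchange runs the backward pass
    # whose docstring this patch corrects.
    output.sum().backward()
    assert feature.grad is not None

The added `bufferer_cnt` bookkeeping in the test mirrors the extra buffering that the feature-fetch pipeline introduces, so the datapipe-graph assertions keep passing with the new flag.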