[GraphBolt][CUDA] Cooperative Minibatching - Feature Loading (#7798)
mfbalin authored Sep 15, 2024
1 parent 864b023 commit 55c224a
Showing 3 changed files with 42 additions and 4 deletions.
41 changes: 38 additions & 3 deletions python/dgl/graphbolt/feature_fetcher.py
@@ -8,6 +8,7 @@
from torch.utils.data import functional_datapipe

from .base import etype_tuple_to_str
from .impl.cooperative_conv import CooperativeConvFunction

from .minibatch_transformer import MiniBatchTransformer

@@ -73,6 +74,16 @@ class FeatureFetcher(MiniBatchTransformer):
If True, the feature fetcher will overlap the UVA feature fetcher
operations with the rest of operations by using an alternative CUDA
stream or utilizing asynchronous operations. Default is True.
cooperative: bool, optional
Boolean indicating whether Cooperative Minibatching should be used.
Cooperative Minibatching was initially proposed in
`Deep Graph Library PR#4337 <https://github.com/dmlc/dgl/pull/4337>`__
and was later fully described in
`Cooperative Minibatching in Graph Neural Networks
<https://arxiv.org/abs/2310.12403>`__. Cooperation between the GPUs
eliminates duplicate work performed across the GPUs due to the
overlapping sampled k-hop neighborhoods of seed nodes when performing
GNN minibatching.
"""

def __init__(
@@ -82,6 +93,7 @@ def __init__(
node_feature_keys=None,
edge_feature_keys=None,
overlap_fetch=True,
cooperative=False,
):
datapipe = datapipe.mark_feature_fetcher_start()
self.feature_store = feature_store
@@ -113,9 +125,12 @@ def __init__(
datapipe = datapipe.transform(
partial(self._execute_stage, i)
).buffer(1)
super().__init__(
datapipe, self._identity if max_val == 0 else self._final_stage
)
if max_val > 0:
datapipe = datapipe.transform(self._final_stage)
if cooperative:
datapipe = datapipe.transform(self._cooperative_exchange)
datapipe = datapipe.buffer()
super().__init__(datapipe)
# A positive value indicates that the overlap optimization is enabled.
self.max_num_stages = max_val

@@ -145,6 +160,26 @@ def _final_stage(data):
features[key] = value.wait()
return data

def _cooperative_exchange(self, data):
subgraph = data.sampled_subgraphs[0]
is_heterogeneous = isinstance(
self.node_feature_keys, Dict
) or isinstance(self.edge_feature_keys, Dict)
if is_heterogeneous:
node_features = {key: {} for key, _ in data.node_features.keys()}
for (key, ntype), feature in data.node_features.items():
node_features[key][ntype] = feature
for key, feature in node_features.items():
new_feature = CooperativeConvFunction.apply(subgraph, feature)
for ntype, tensor in new_feature.items():
data.node_features[(key, ntype)] = tensor
else:
for key in data.node_features:
feature = data.node_features[key]
new_feature = CooperativeConvFunction.apply(subgraph, feature)
data.node_features[key] = new_feature
return data

def _read(self, data):
"""
Fill in the node/edge features field in data.
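The following is a minimal usage sketch of the new `cooperative` flag documented above. Only the `FeatureFetcher` signature comes from this commit; `itemset`, `graph`, `feature_store`, `device`, and the surrounding pipeline stages (`ItemSampler`, `copy_to`, `sample_neighbor`, `DataLoader`) are assumed, illustrative context for a multi-GPU run where each rank builds its own datapipe.

# Hypothetical per-rank pipeline; names other than FeatureFetcher's arguments are placeholders.
import dgl.graphbolt as gb

datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
datapipe = datapipe.copy_to(device)
datapipe = datapipe.sample_neighbor(graph, fanouts=[10, 10])
datapipe = gb.FeatureFetcher(
    datapipe,
    feature_store,
    node_feature_keys=["feat"],
    overlap_fetch=True,  # overlap UVA feature fetches with the rest of the pipeline
    cooperative=True,    # redistribute the fetched features across GPUs
)
dataloader = gb.DataLoader(datapipe)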
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/impl/cooperative_conv.py
@@ -57,7 +57,7 @@ def forward(
def backward(
ctx, grad_output: Union[torch.Tensor, Dict[str, torch.Tensor]]
):
"""Implements the forward pass."""
"""Implements the backward pass."""
(
counts_sent,
counts_received,
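The one-line change above fixes a copy-pasted docstring: `CooperativeConvFunction.backward` implements the backward, not the forward, pass. As a general illustration of why the backward of such an all-to-all feature exchange is simply the reverse exchange (this is not the actual `CooperativeConvFunction` implementation; all names below are hypothetical):

import torch
import torch.distributed as dist


class AllToAllExchange(torch.autograd.Function):
    """Sketch: exchange feature rows across ranks with torch.distributed."""

    @staticmethod
    def forward(ctx, x, counts_sent, counts_received):
        # Remember the split sizes so the backward pass can reverse the exchange.
        ctx.counts_sent = counts_sent
        ctx.counts_received = counts_received
        out = x.new_empty((sum(counts_received),) + x.shape[1:])
        dist.all_to_all_single(
            out,
            x,
            output_split_sizes=counts_received,
            input_split_sizes=counts_sent,
        )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # The backward pass mirrors the forward pass with the split sizes
        # swapped: gradients flow back to the ranks their rows came from.
        grad_input = grad_output.new_empty(
            (sum(ctx.counts_sent),) + grad_output.shape[1:]
        )
        dist.all_to_all_single(
            grad_input,
            grad_output.contiguous(),
            output_split_sizes=ctx.counts_sent,
            input_split_sizes=ctx.counts_received,
        )
        return grad_input, None, None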
3 changes: 3 additions & 0 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -146,6 +146,7 @@ def test_gpu_sampling_DataLoader(
["a", "b", "c"],
["d"],
overlap_fetch=overlap_feature_fetch and i == 0,
cooperative=asynchronous and cooperative and i == 0,
)
dataloaders.append(dgl.graphbolt.DataLoader(datapipe))
dataloader, dataloader2 = dataloaders
@@ -159,6 +160,8 @@
bufferer_cnt += 2 * num_layers + 1 # _preprocess stage has 1.
if cooperative:
bufferer_cnt += 3 * num_layers
if enable_feature_fetch:
bufferer_cnt += 1 # feature fetch has 1.
if cooperative:
# _preprocess stage and each sampling layer.
bufferer_cnt += 3
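The arithmetic above predicts how many `Bufferer` datapipes the assembled pipeline should contain. A hedged sketch of how such a count could be checked follows; it assumes the buffering datapipe is exported as `dgl.graphbolt.Bufferer` and uses `torch.utils.data.graph.traverse_dps`, which may differ from the helpers the test suite actually uses.

# Illustrative only: count datapipes of a given type in a datapipe graph.
from torch.utils.data.graph import traverse_dps

import dgl.graphbolt as gb


def count_datapipes(datapipe, datapipe_type):
    """Walk the (possibly shared) datapipe graph once and count matching nodes."""
    graph = traverse_dps(datapipe)  # {id: (datapipe, {child_id: (...), ...}), ...}
    seen = set()
    count = 0

    def visit(subgraph):
        nonlocal count
        for dp_id, (dp, children) in subgraph.items():
            if dp_id in seen:
                continue
            seen.add(dp_id)
            if isinstance(dp, datapipe_type):
                count += 1
            visit(children)

    visit(graph)
    return count


# e.g. assert count_datapipes(dataloader, gb.Bufferer) == bufferer_cnt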
