[GraphBolt][CUDA] rank_sort_async for Cooperative Minibatching. (#7805)
mfbalin authored Sep 19, 2024
1 parent 31ad9b5 commit 5ae6400
Showing 5 changed files with 52 additions and 12 deletions.
11 changes: 11 additions & 0 deletions graphbolt/src/cuda/cooperative_minibatching_utils.cu
@@ -25,6 +25,7 @@
 #include <cub/cub.cuh>
 #include <cuda/functional>

+#include "../utils.h"
 #include "./common.h"
 #include "./cooperative_minibatching_utils.cuh"
 #include "./cooperative_minibatching_utils.h"
@@ -144,5 +145,15 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
   return results;
 }

+c10::intrusive_ptr<Future<
+    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
+RankSortAsync(
+    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,
+    const int64_t world_size) {
+  return async(
+      [=] { return RankSort(nodes_list, rank, world_size); },
+      utils::is_on_gpu(nodes_list.at(0)));
+}
+
 }  // namespace cuda
 }  // namespace graphbolt
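Note: RankSortAsync is a thin wrapper that defers the existing RankSort through GraphBolt's async() helper (declared in graphbolt/async.h) and returns a Future; the boolean argument signals that the inputs are GPU tensors. A minimal usage sketch from the Python side, assuming a CUDA build of GraphBolt and an illustrative seed tensor:

import dgl.graphbolt  # registers the torch.ops.graphbolt.* operators
import torch

seeds = torch.randint(0, 10_000, (4096,), device="cuda")  # hypothetical seeds
rank, world_size = 0, 4

# Returns immediately with a future; the sort runs in the background.
future = torch.ops.graphbolt.rank_sort_async([seeds], rank, world_size)

# ... other per-minibatch work can overlap here ...

# wait() blocks and yields one (sorted_nodes, index, offsets) tuple per input.
[(sorted_nodes, index, offsets)] = future.wait()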
7 changes: 7 additions & 0 deletions graphbolt/src/cuda/cooperative_minibatching_utils.h
@@ -22,6 +22,7 @@
 #define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_H_

 #include <ATen/cuda/CUDAEvent.h>
+#include <graphbolt/async.h>
 #include <torch/script.h>

 namespace graphbolt {
@@ -83,6 +84,12 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
     const std::vector<torch::Tensor>& nodes_list, int64_t rank,
     int64_t world_size);

+c10::intrusive_ptr<Future<
+    std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
+RankSortAsync(
+    const std::vector<torch::Tensor>& nodes_list, const int64_t rank,
+    const int64_t world_size);
+
 }  // namespace cuda
 }  // namespace graphbolt

8 changes: 8 additions & 0 deletions graphbolt/src/python_binding.cc
@@ -59,6 +59,13 @@ TORCH_LIBRARY(graphbolt, m) {
           &Future<std::vector<std::tuple<
               torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>::
               Wait);
+  m.class_<Future<
+      std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>(
+      "RankSortFuture")
+      .def(
+          "wait",
+          &Future<std::vector<
+              std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
   m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
       "GpuGraphCacheQueryFuture")
       .def(
@@ -198,6 +205,7 @@ TORCH_LIBRARY(graphbolt, m) {
 #ifdef GRAPHBOLT_USE_CUDA
   m.def("set_max_uva_threads", &cuda::set_max_uva_threads);
   m.def("rank_sort", &cuda::RankSort);
+  m.def("rank_sort_async", &cuda::RankSortAsync);
 #endif
 #ifdef HAS_IMPL_ABSTRACT_PYSTUB
   m.impl_abstract_pystub("dgl.graphbolt.base", "//dgl.graphbolt.base");
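The registration above is what lets Python hold the returned future as a RankSortFuture and call wait() on it; the payload matches the synchronous rank_sort. A sketch of consuming the result, where reading offsets as world_size + 1 boundaries delimiting per-rank segments is an assumption inferred from how subgraph_sampler.py uses it below:

# `future`, `world_size` as in the earlier sketch.
for sorted_nodes, index, offsets in future.wait():
    for r in range(world_size):
        # Assumed layout: the seeds this minibatch exchanges with rank r.
        segment = sorted_nodes[offsets[r] : offsets[r + 1]]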
34 changes: 24 additions & 10 deletions python/dgl/graphbolt/subgraph_sampler.py
@@ -140,6 +140,9 @@ def __init__(
         if cooperative:
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_1)
             datapipe = datapipe.buffer()
+            datapipe = datapipe.transform(
+                self._seeds_cooperative_exchange_1_wait_future
+            ).buffer()
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_2)
             datapipe = datapipe.buffer()
             datapipe = datapipe.transform(self._seeds_cooperative_exchange_3)
@@ -193,19 +196,32 @@ def _wait_preprocess_future(minibatch, cooperative: bool):
         return minibatch

     @staticmethod
-    def _seeds_cooperative_exchange_1(minibatch, group=None):
-        rank = thd.get_rank(group)
-        world_size = thd.get_world_size(group)
+    def _seeds_cooperative_exchange_1(minibatch):
+        rank = thd.get_rank()
+        world_size = thd.get_world_size()
         seeds = minibatch._seed_nodes
         is_homogeneous = not isinstance(seeds, dict)
         if is_homogeneous:
             seeds = {"_N": seeds}
         if minibatch._seeds_offsets is None:
-            seeds_list = list(seeds.values())
-            result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
             assert minibatch.compacted_seeds is None
+            minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async(
+                list(seeds.values()), rank, world_size
+            )
+        return minibatch
+
+    @staticmethod
+    def _seeds_cooperative_exchange_1_wait_future(minibatch):
+        world_size = thd.get_world_size()
+        seeds = minibatch._seed_nodes
+        is_homogeneous = not isinstance(seeds, dict)
+        if is_homogeneous:
+            seeds = {"_N": seeds}
+        num_ntypes = len(seeds.keys())
+        if minibatch._seeds_offsets is None:
+            result = minibatch._rank_sort_future.wait()
+            delattr(minibatch, "_rank_sort_future")
             sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
-            num_ntypes = len(seeds.keys())
             for i, (
                 seed_type,
                 (typed_sorted_seeds, typed_index, typed_offsets),
@@ -229,16 +245,15 @@ def _seeds_cooperative_exchange_1(minibatch, group=None):
         minibatch._counts_future = all_to_all(
             counts_received.split(num_ntypes),
             counts_sent.split(num_ntypes),
-            group=group,
             async_op=True,
         )
         minibatch._counts_sent = counts_sent
         minibatch._counts_received = counts_received
         return minibatch

     @staticmethod
-    def _seeds_cooperative_exchange_2(minibatch, group=None):
-        world_size = thd.get_world_size(group)
+    def _seeds_cooperative_exchange_2(minibatch):
+        world_size = thd.get_world_size()
         seeds = minibatch._seed_nodes
         minibatch._counts_future.wait()
         delattr(minibatch, "_counts_future")
@@ -256,7 +271,6 @@ def _seeds_cooperative_exchange_2(minibatch, group=None):
             all_to_all(
                 typed_seeds_received.split(typed_counts_received),
                 typed_seeds.split(typed_counts_sent),
-                group,
             )
             seeds_received[ntype] = typed_seeds_received
             counts_sent[ntype] = typed_counts_sent
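The sampler change follows a launch/buffer/wait pattern: _seeds_cooperative_exchange_1 now only launches rank_sort_async and stashes the future on the minibatch, the interposed Bufferer lets earlier stages run ahead, and the new _seeds_cooperative_exchange_1_wait_future stage collects the result, so the GPU-side sort overlaps with other pipeline work. A condensed sketch of the pattern with hypothetical stage functions (transform()/buffer() are the graphbolt datapipe helpers used in the diff above):

import torch
import torch.distributed as thd

def _launch(minibatch):
    # Non-blocking: start the device-side sort and keep the future around.
    minibatch._rank_sort_future = torch.ops.graphbolt.rank_sort_async(
        list(minibatch._seed_nodes.values()),
        thd.get_rank(),
        thd.get_world_size(),
    )
    return minibatch

def _wait(minibatch):
    # Runs one buffered minibatch later, so the sort has usually finished.
    # (`_rank_sort_result` is an illustrative attribute name.)
    minibatch._rank_sort_result = minibatch._rank_sort_future.wait()
    delattr(minibatch, "_rank_sort_future")
    return minibatch

datapipe = datapipe.transform(_launch).buffer()  # buffer() enables the overlap
datapipe = datapipe.transform(_wait)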
4 changes: 2 additions & 2 deletions tests/python/pytorch/graphbolt/test_dataloader.py
@@ -163,8 +163,8 @@ def test_gpu_sampling_DataLoader(
     if enable_feature_fetch:
         bufferer_cnt += 1  # feature fetch has 1.
     if cooperative:
-        # _preprocess stage and each sampling layer.
-        bufferer_cnt += 3
+        # _preprocess stage.
+        bufferer_cnt += 4
     datapipe_graph = traverse_dps(dataloader)
     bufferers = find_dps(
         datapipe_graph,
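The expected count rises from 3 to 4 because cooperative mode now contributes one more Bufferer: the new wait-future stage is registered with its own .buffer(). The test tallies Bufferer datapipes in the traversed graph, roughly as follows (the diff truncates the call, so argument details may differ in the test file):

datapipe_graph = traverse_dps(dataloader)
bufferers = find_dps(datapipe_graph, dgl.graphbolt.Bufferer)
assert len(bufferers) == bufferer_cnt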
