Skip to content

Commit

Permalink
[GraphBolt][CUDA] Expose UniqueAndCompact offsets. (#7789)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Sep 9, 2024
1 parent bbe00c0 commit e8022e9
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 54 deletions.
11 changes: 8 additions & 3 deletions graphbolt/include/graphbolt/cuda_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,17 @@ torch::Tensor IndptrEdgeIdsImpl(
* @param rank The rank of the current GPU.
* @param world_size The total # GPUs, world size.
*
* @return
* @return (unique_ids, compacted_src_ids, compacted_dst_ids, unique_offsets)
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
* removing duplicates. The indices in this tensor precisely match the compacted
* IDs of the corresponding elements.
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the offsets into the unique_ids tensor. Has
* size `world_size + 1` and unique_ids[offsets[i]: offsets[i + 1]] belongs to
* the rank `(rank + i) % world_size`.
*
* @example
* torch::Tensor src_ids = src
Expand All @@ -306,7 +309,8 @@ torch::Tensor IndptrEdgeIdsImpl(
* torch::Tensor compacted_src_ids = std::get<1>(result);
* torch::Tensor compacted_dst_ids = std::get<2>(result);
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
UniqueAndCompact(
const torch::Tensor src_ids, const torch::Tensor dst_ids,
const torch::Tensor unique_dst_ids, const int64_t rank,
const int64_t world_size);
Expand All @@ -316,7 +320,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
 * value is equal to passing the ith elements of the input arguments to
* UniqueAndCompact.
*/
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
Expand Down
15 changes: 10 additions & 5 deletions graphbolt/include/graphbolt/unique_and_compact.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@ namespace sampling {
* @param rank The rank of the current GPU.
* @param world_size The total # GPUs, world size.
*
* @return
* @return (unique_ids, compacted_src_ids, compacted_dst_ids, unique_offsets)
* - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after
* removing duplicates. The indices in this tensor precisely match the compacted
* IDs of the corresponding elements.
* - The tensor corresponding to the 'src_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the 'dst_ids' tensor, where the entries are
* mapped to compacted IDs.
* - The tensor corresponding to the offsets into the unique_ids tensor. Has
* size `world_size + 1` and unique_ids[offsets[i]: offsets[i + 1]] belongs to
* the rank `(rank + i) % world_size`.
*
* @example
* torch::Tensor src_ids = src
Expand All @@ -56,20 +59,22 @@ namespace sampling {
* torch::Tensor compacted_src_ids = std::get<1>(result);
* torch::Tensor compacted_dst_ids = std::get<2>(result);
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids, const int64_t rank,
const int64_t world_size);

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
const std::vector<torch::Tensor> unique_dst_ids, const int64_t rank,
const int64_t world_size);

c10::intrusive_ptr<Future<
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
c10::intrusive_ptr<Future<std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>
UniqueAndCompactBatchedAsync(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
Expand Down
31 changes: 20 additions & 11 deletions graphbolt/src/cuda/unique_and_compact_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ UniqueAndCompactBatchedSortBased(
}));
}

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
Expand All @@ -282,15 +283,8 @@ UniqueAndCompactBatched(
// Utilizes a hash table based implementation, the mapped id of a vertex
// will be monotonically increasing as the first occurrence index of it in
// torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.
auto results4 = UniqueAndCompactBatchedHashMapBased(
return UniqueAndCompactBatchedHashMapBased(
src_ids, dst_ids, unique_dst_ids, rank, world_size);
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
results3;
// TODO @mfbalin: expose the `d` result in a later PR.
for (const auto& [a, b, c, d] : results4) {
results3.emplace_back(a, b, c);
}
return results3;
}
TORCH_CHECK(
world_size <= 1,
Expand All @@ -299,10 +293,25 @@ UniqueAndCompactBatched(
// Utilizes a sort based algorithm, the mapped id of a vertex part of the
// src_ids but not part of the unique_dst_ids will be monotonically increasing
// as the actual vertex id increases. Thus, it is deterministic.
return UniqueAndCompactBatchedSortBased(src_ids, dst_ids, unique_dst_ids);
auto results3 =
UniqueAndCompactBatchedSortBased(src_ids, dst_ids, unique_dst_ids);
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
results4;
auto offsets = torch::zeros(
2 * results3.size(),
c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
for (const auto& [a, b, c] : results3) {
auto d = offsets.slice(0, 0, 2);
d.data_ptr<int64_t>()[1] = a.size(0);
results4.emplace_back(a, b, c, d);
offsets = offsets.slice(0, 2);
}
return results4;
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
UniqueAndCompact(
const torch::Tensor src_ids, const torch::Tensor dst_ids,
const torch::Tensor unique_dst_ids, const int64_t rank,
const int64_t world_size) {
Expand Down
9 changes: 5 additions & 4 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@ TORCH_LIBRARY(graphbolt, m) {
m.class_<Future<c10::intrusive_ptr<FusedSampledSubgraph>>>(
"FusedSampledSubgraphFuture")
.def("wait", &Future<c10::intrusive_ptr<FusedSampledSubgraph>>::Wait);
m.class_<Future<
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>(
m.class_<Future<std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>(
"UniqueAndCompactBatchedFuture")
.def(
"wait",
&Future<std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
&Future<std::vector<std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>::
Wait);
m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
"GpuGraphCacheQueryFuture")
.def(
Expand Down
19 changes: 13 additions & 6 deletions graphbolt/src/unique_and_compact.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

namespace graphbolt {
namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids, const int64_t rank,
const int64_t world_size) {
Expand All @@ -31,16 +32,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
"Cooperative Minibatching (arXiv:2310.12403) is supported only on GPUs.");
auto num_dst = unique_dst_ids.size(0);
torch::Tensor ids = torch::cat({unique_dst_ids, src_ids});
return AT_DISPATCH_INDEX_TYPES(
auto [unique_ids, compacted_src, compacted_dst] = AT_DISPATCH_INDEX_TYPES(
ids.scalar_type(), "unique_and_compact", ([&] {
ConcurrentIdHashMap<index_t> id_map(ids, num_dst);
return std::make_tuple(
id_map.GetUniqueIds(), id_map.MapIds(src_ids),
id_map.MapIds(dst_ids));
}));
auto offsets = torch::zeros(2, c10::TensorOptions().dtype(torch::kInt64));
offsets.data_ptr<int64_t>()[1] = unique_ids.size(0);
return {unique_ids, compacted_src, compacted_dst, offsets};
}

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
Expand All @@ -64,7 +69,9 @@ UniqueAndCompactBatched(
src_ids, dst_ids, unique_dst_ids, rank, world_size);
});
}
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results;
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
results;
results.reserve(src_ids.size());
for (std::size_t i = 0; i < src_ids.size(); i++) {
results.emplace_back(UniqueAndCompact(
Expand All @@ -73,8 +80,8 @@ UniqueAndCompactBatched(
return results;
}

c10::intrusive_ptr<Future<
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>>
c10::intrusive_ptr<Future<std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>>>
UniqueAndCompactBatchedAsync(
const std::vector<torch::Tensor>& src_ids,
const std::vector<torch::Tensor>& dst_ids,
Expand Down
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/in_subgraph_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def sample_subgraphs(
(
original_row_node_ids,
compacted_csc_formats,
_,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_formats,
Expand Down
7 changes: 6 additions & 1 deletion python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def _compact_per_layer(self, minibatch):
(
original_row_node_ids,
compacted_csc_format,
_,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
Expand Down Expand Up @@ -506,7 +507,11 @@ def _compact_per_layer_async(self, minibatch):
def _compact_per_layer_wait_future(minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
original_row_node_ids, compacted_csc_format = minibatch._future.wait()
(
original_row_node_ids,
compacted_csc_format,
_,
) = minibatch._future.wait()
delattr(minibatch, "_future")
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
Expand Down
Loading

0 comments on commit e8022e9

Please sign in to comment.