-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GraphBolt][CUDA] Expose
RankSort
to python, reorganize and test. (#…
- Loading branch information
Showing
6 changed files
with
125 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/** | ||
* Copyright (c) 2024, mfbalin (Muhammed Fatih Balin) | ||
* All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
* | ||
* @file cuda/cooperative_minibatching_utils.cuh | ||
* @brief Cooperative Minibatching (arXiv:2310.12403) utility device functions | ||
* in CUDA. | ||
*/ | ||
#ifndef GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_CUH_ | ||
#define GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_CUH_ | ||
|
||
#include <curand_kernel.h> | ||
|
||
namespace graphbolt { | ||
namespace cuda { | ||
|
||
// Device-side storage type for a partition/rank id. uint8_t can represent
// ranks 0..255, which caps the supported cooperating world size at 256 GPUs.
using part_t = uint8_t;
// Matching torch scalar type for tensors that hold part_t values.
// NOTE(review): torch::kUInt8 requires a torch header that is not visible in
// this chunk (only <curand_kernel.h> is) — confirm it is included upstream.
constexpr auto kPartDType = torch::kUInt8;
|
||
/**
 * @brief Deterministically maps a vertex id to the rank of the GPU that owns
 * it. The same (id, rank, world_size) triple always yields the same answer.
 *
 * @param id The node id that will be mapped to a rank in [0, world_size).
 * @param rank The rank of the current GPU.
 * @param world_size The world size, the total number of cooperating GPUs.
 *
 * @return The rank of the GPU the given id is mapped to.
 */
template <typename index_t>
__device__ inline auto rank_assignment(
    index_t id, uint32_t rank, uint32_t world_size) {
  // Consider using a faster implementation in the future.
  constexpr uint64_t kCurandSeed = 999961;  // Any random number.
  // Philox is counter-based: seeding with a fixed seed and the vertex id as
  // the offset produces a reproducible pseudo-random hash of the id.
  curandStatePhilox4_32_10_t state;
  curand_init(kCurandSeed, 0, id, &state);
  const uint32_t hashed_id = curand(&state);
  // Subtracting rank before the modulo makes the result a rank-dependent
  // rotation of a rank-independent hash; unsigned wraparound is well-defined.
  return (hashed_id - rank) % world_size;
}
|
||
} // namespace cuda | ||
} // namespace graphbolt | ||
|
||
#endif // GRAPHBOLT_CUDA_COOPERATIVE_MINIBATCHING_UTILS_CUH_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import unittest | ||
|
||
from functools import partial | ||
|
||
import backend as F | ||
import dgl.graphbolt as gb | ||
import pytest | ||
import torch | ||
|
||
# Simulated number of cooperating GPUs; the `rank` test parameter below is
# parametrized over range(WORLD_SIZE).
WORLD_SIZE = 7

# Exact (zero-tolerance) tensor comparison: the tests below check permutations
# of integer node ids, so results must match element-for-element.
assert_equal = partial(torch.testing.assert_close, rtol=0, atol=0)
|
||
|
||
@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="This test requires an NVIDIA GPU.",
)
@pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("rank", list(range(WORLD_SIZE)))
def test_gpu_cached_feature_read_async(dtype, rank):
    # NOTE(review): the name looks copy-pasted from another test — this
    # actually exercises torch.ops.graphbolt.rank_sort. Kept as-is so test
    # discovery/CI filters are unaffected.
    unsorted_batches = [
        torch.randint(0, 11111111, [777], dtype=dtype, device=F.ctx())
        for _ in range(10)
    ]
    sorted_batches = [batch.sort()[0] for batch in unsorted_batches]

    res1 = torch.ops.graphbolt.rank_sort(unsorted_batches, rank, WORLD_SIZE)
    res2 = torch.ops.graphbolt.rank_sort(sorted_batches, rank, WORLD_SIZE)

    # rank_sort's output should be invariant to the input's initial order:
    # undoing the returned index permutation must recover the original input,
    # and the per-rank offsets must agree between sorted and unsorted inputs.
    for batch, sorted_batch, part1, part2 in zip(
        unsorted_batches, sorted_batches, res1, res2
    ):
        nodes1, idx1, offsets1 = part1
        nodes2, idx2, offsets2 = part2
        assert_equal(batch, nodes1[idx1.sort()[1]])
        assert_equal(sorted_batch, nodes2[idx2.sort()[1]])
        assert_equal(offsets1, offsets2)
        assert offsets1.is_pinned() and offsets2.is_pinned()

    # This function is deterministic. Call with identical arguments and check.
    res3 = torch.ops.graphbolt.rank_sort(unsorted_batches, rank, WORLD_SIZE)
    for part1, part3 in zip(res1, res3):
        for tensor1, tensor3 in zip(part1, part3):
            assert_equal(tensor1, tensor3)

    # The dependency on the rank argument is simply a permutation: rank r's
    # partition i equals rank 0's partition (i - r) mod WORLD_SIZE.
    res4 = torch.ops.graphbolt.rank_sort(unsorted_batches, 0, WORLD_SIZE)
    for (nodes1, idx1, offsets1), (nodes4, idx4, offsets4) in zip(res1, res4):
        bounds1 = offsets1.tolist()
        bounds4 = offsets4.tolist()
        for part in range(WORLD_SIZE):
            shifted = (part - rank + WORLD_SIZE) % WORLD_SIZE
            assert_equal(
                nodes1[bounds1[shifted] : bounds1[shifted + 1]],
                nodes4[bounds4[part] : bounds4[part + 1]],
            )
            assert_equal(
                idx1[bounds1[shifted] : bounds1[shifted + 1]],
                idx4[bounds4[part] : bounds4[part + 1]],
            )