Skip to content

Commit

Permalink
[GraphBolt] Implement dependent minibatching for labor. (#7205)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Mar 11, 2024
1 parent 2bda158 commit f0c7efa
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 46 deletions.
97 changes: 97 additions & 0 deletions graphbolt/include/graphbolt/continuous_seed.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* @file graphbolt/continuous_seed.h
* @brief CPU and CUDA implementation for continuous random seeds
*/
#ifndef GRAPHBOLT_CONTINUOUS_SEED_H_
#define GRAPHBOLT_CONTINUOUS_SEED_H_

#include <torch/script.h>

#include <cmath>

#ifdef __CUDACC__
#include <curand_kernel.h>
#else
#include <pcg_random.hpp>
#include <random>
#endif // __CUDA_ARCH__

#ifndef M_SQRT1_2
#define M_SQRT1_2 0.707106781186547524401
#endif // M_SQRT1_2

namespace graphbolt {

class continuous_seed {
uint64_t s[2];
float c[2];

public:
/* implicit */ continuous_seed(const int64_t seed) { // NOLINT
s[0] = s[1] = seed;
c[0] = c[1] = 0;
}

continuous_seed(torch::Tensor seed_arr, float r) {
auto seed = seed_arr.data_ptr<int64_t>();
s[0] = seed[0];
s[1] = seed[seed_arr.size(0) - 1];
const auto pi = std::acos(-1.0);
c[0] = std::cos(pi * r / 2);
c[1] = std::sin(pi * r / 2);
}

#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t t) const {
const uint64_t kCurandSeed = 999961; // Could be any random number.
curandStatePhilox4_32_10_t rng;
curand_init(kCurandSeed, s[0], t, &rng);
float rnd;
if (s[0] != s[1]) {
rnd = c[0] * curand_normal(&rng);
curand_init(kCurandSeed, s[1], t, &rng);
rnd += c[1] * curand_normal(&rng);
rnd = normcdff(rnd);
} else {
rnd = curand_uniform(&rng);
}
return rnd;
}
#else
inline float uniform(const uint64_t t) const {
pcg32 ng0(s[0], t);
float rnd;
if (s[0] != s[1]) {
std::normal_distribution<float> norm;
rnd = c[0] * norm(ng0);
pcg32 ng1(s[1], t);
norm.reset();
rnd += c[1] * norm(ng1);
rnd = std::erfc(-rnd * static_cast<float>(M_SQRT1_2)) / 2.0f;
} else {
std::uniform_real_distribution<float> uni;
rnd = uni(ng0);
}
return rnd;
}
#endif // __CUDA_ARCH__
};

} // namespace graphbolt

#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
3 changes: 2 additions & 1 deletion graphbolt/include/graphbolt/fused_csc_sampling_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_

#include <graphbolt/continuous_seed.h>
#include <graphbolt/fused_sampled_subgraph.h>
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>
Expand All @@ -27,7 +28,7 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};
template <>
struct SamplerArgs<SamplerType::LABOR> {
const torch::Tensor& indices;
int64_t random_seed;
continuous_seed random_seed;
int64_t num_nodes;
};

Expand Down
21 changes: 6 additions & 15 deletions graphbolt/src/cuda/neighbor_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
#include <c10/core/ScalarType.h>
#include <curand_kernel.h>
#include <graphbolt/continuous_seed.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
Expand Down Expand Up @@ -41,27 +42,17 @@ __global__ void _ComputeRandoms(
const int64_t num_edges, const indptr_t* const sliced_indptr,
const indptr_t* const sub_indptr, const indices_t* const csr_rows,
const weights_t* const sliced_weights, const indices_t* const indices,
const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
const continuous_seed random_seed, float_t* random_arr,
edge_id_t* edge_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x;
curandStatePhilox4_32_10_t rng;
const auto labor = indices != nullptr;

if (!labor) {
curand_init(random_seed, i, 0, &rng);
}

while (i < num_edges) {
const auto row_position = csr_rows[i];
const auto row_offset = i - sub_indptr[row_position];
const auto in_idx = sliced_indptr[row_position] + row_offset;

if (labor) {
constexpr uint64_t kCurandSeed = 999961;
curand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
}

const auto rnd = curand_uniform(&rng);
const auto rnd = random_seed.uniform(labor ? indices[in_idx] : i);
const auto prob =
sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
const auto exp_rnd = -__logf(rnd);
Expand Down Expand Up @@ -211,8 +202,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto coo_rows = ExpandIndptrImpl(
sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
num_edges = coo_rows.size(0);
const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
const continuous_seed random_seed(RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()));
auto output_indptr = torch::empty_like(sub_indptr);
torch::Tensor picked_eids;
torch::Tensor output_indices;
Expand Down
25 changes: 21 additions & 4 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,25 @@ inline void safe_divide(T& a, U b) {
a = b > 0 ? (T)(a / b) : std::numeric_limits<T>::infinity();
}

namespace labor {

template <typename T>
inline T invcdf(T u, int64_t n, T rem) {
constexpr T one = 1;
return rem * (one - std::pow(one - u, one / n));
}

template <typename T>
inline T jth_sorted_uniform_random(
continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const T u = seed.uniform(t + j * c);
// https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem);
return 1 - rem;
}

}; // namespace labor

/**
* @brief Perform uniform-nonuniform sampling of elements depending on the
* template parameter NonUniform and return the sampled indices.
Expand Down Expand Up @@ -1563,8 +1582,7 @@ inline int64_t LaborPick(
// O(num_neighbors).
for (uint32_t i = 0; i < fanout; ++i) {
const auto t = local_indices_data[i];
auto rnd =
labor::uniform_random<float>(args.random_seed, t); // r_t
auto rnd = args.random_seed.uniform(t); // r_t
if constexpr (NonUniform) {
safe_divide(rnd, local_probs_data[i]);
} // r_t / \pi_t
Expand All @@ -1575,8 +1593,7 @@ inline int64_t LaborPick(
}
for (uint32_t i = fanout; i < num_neighbors; ++i) {
const auto t = local_indices_data[i];
auto rnd =
labor::uniform_random<float>(args.random_seed, t); // r_t
auto rnd = args.random_seed.uniform(t); // r_t
if constexpr (NonUniform) {
safe_divide(rnd, local_probs_data[i]);
} // r_t / \pi_t
Expand Down
26 changes: 0 additions & 26 deletions graphbolt/src/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,32 +76,6 @@ class RandomEngine {
pcg32 rng_;
};

namespace labor {

template <typename T>
inline T uniform_random(int64_t random_seed, int64_t t) {
pcg32 ng(random_seed, t);
std::uniform_real_distribution<T> uni;
return uni(ng);
}

template <typename T>
inline T invcdf(T u, int64_t n, T rem) {
constexpr T one = 1;
return rem * (one - std::pow(one - u, one / n));
}

template <typename T>
inline T jth_sorted_uniform_random(
int64_t random_seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const auto u = uniform_random<T>(random_seed, t + j * c);
// https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem);
return 1 - rem;
}

}; // namespace labor

} // namespace graphbolt

#endif // GRAPHBOLT_RANDOM_H_

0 comments on commit f0c7efa

Please sign in to comment.