implement warp shuffle based reduction; enable for arch >= 3.0
killeent committed Apr 18, 2017
1 parent edf3c71 commit 4990147
Showing 2 changed files with 118 additions and 2 deletions.
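For context (not part of the commit itself): on devices of compute capability 3.0 and newer, threads within a warp can exchange register values directly through the __shfl_* intrinsics, so a warp-wide reduction no longer needs a round trip through shared memory. A minimal sketch of the pattern this change builds on, assuming a simple sum over one float per lane (illustrative only, using the same pre-sync __shfl_xor intrinsic the diff uses):

// Illustrative sketch, not code from this commit: XOR-butterfly warp reduction.
__device__ float warpSum(float val) {
  // After log2(warpSize) rounds every lane holds the sum of all 32 inputs.
  for (int mask = 1; mask < warpSize; mask *= 2) {
    val += __shfl_xor(val, mask);
  }
  return val;
}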
102 changes: 100 additions & 2 deletions lib/THC/THCReduceApplyUtils.cuh
@@ -8,6 +8,7 @@
#include "THCTensor.h"
#include "THCDeviceUtils.cuh"
#include "THCTensorInfo.cuh"
#include "THCAsmUtils.cuh"

// Enum that indicates whether tensor arguments are read/write or
// read-only
@@ -26,8 +27,101 @@ __device__ __forceinline__ IndexType getLinearBlockId() {
// level) with N elements per thread in the block, so we have to use min(numvals,
// max block size) to determine this count.
template <typename T, int N>
-int reduceSmemSize(THCState *state, int numVals) {
-  return THCRoundUp(std::min(numVals, 1024), 32) * N * sizeof(T);
+int reduceSmemSize(THCState *state, long numVals) {
+  // check if we can use a warp shuffle
+  cudaDeviceProp *props = THCState_getCurrentDeviceProperties(state);
+  if (props->major >= 3) {
+    return props->warpSize * N * sizeof(T);
+  } else {
+    return THCRoundUp(std::min(numVals, (long) props->maxThreadsPerBlock), (long) props->warpSize) * N * sizeof(T);
+  }
 }
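A rough worked example of the shared-memory saving (illustrative numbers, not from the commit): with T = float, N = 1, numVals = 4096 on a device with warpSize = 32 and maxThreadsPerBlock = 1024, the arch >= 3.0 branch reserves 32 * 1 * 4 = 128 bytes, while the fallback branch reserves THCRoundUp(min(4096, 1024), 32) * 1 * 4 = 4096 bytes. With warp shuffles, only one shared slot per warp (per reduced value) is needed rather than one per thread.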

template <typename T>
struct THCWarpUtils {
  static __device__ __forceinline__ T shflxor(T val, unsigned int mask) {
    return __shfl_xor(val, mask);
  }
};

template <>
struct THCWarpUtils<unsigned char> {
  static __device__ __forceinline__ unsigned char shflxor(unsigned char val, unsigned int mask) {
    return (unsigned char) __shfl_xor((int) val, mask);
  }
};

template <>
struct THCWarpUtils<char> {
  static __device__ __forceinline__ char shflxor(char val, unsigned int mask) {
    return (char) __shfl_xor((int) val, mask);
  }
};

template <>
struct THCWarpUtils<short> {
  static __device__ __forceinline__ short shflxor(short val, unsigned int mask) {
    return (short) __shfl_xor((int) val, mask);
  }
};

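// The 64-bit types below are shuffled as two 32-bit halves: the value is
// reinterpreted as an int2 and each component is passed through __shfl_xor.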
template <>
struct THCWarpUtils<double> {
  static __device__ __forceinline__ double shflxor(double val, unsigned int mask) {
    int2 a = *reinterpret_cast<int2*>(&val);
    a.x = __shfl_xor(a.x, mask);
    a.y = __shfl_xor(a.y, mask);
    return *reinterpret_cast<double*>(&a);
  }
};

template <>
struct THCWarpUtils<long> {
  static __device__ __forceinline__ long shflxor(long val, unsigned int mask) {
    int2 a = *reinterpret_cast<int2*>(&val);
    a.x = __shfl_xor(a.x, mask);
    a.y = __shfl_xor(a.y, mask);
    return *reinterpret_cast<long*>(&a);
  }
};

template <typename T, typename ReduceOp, int N>
__device__ void warpReduce(T threadVals[N], ReduceOp reduceOp) {
#pragma unroll
  for (int mask = 1; mask < warpSize; mask *= 2) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      T neighbor = THCWarpUtils<T>::shflxor(threadVals[i], mask);
      threadVals[i] = reduceOp(threadVals[i], neighbor);
    }
  }
}
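To make the exchange pattern concrete, here is a small host-side simulation (illustrative only, not part of the diff) of the XOR butterfly that warpReduce performs, using a 4-lane "warp" and a sum operator. In each round, lane i combines its value with the value held by lane i ^ mask, so after log2(warpSize) rounds every lane holds the full reduction:

#include <cstdio>

int main() {
  const int lanes = 4;           // stand-in for warpSize
  int v[lanes] = {1, 2, 3, 4};   // one value per "lane"
  for (int mask = 1; mask < lanes; mask *= 2) {
    int next[lanes];
    for (int i = 0; i < lanes; ++i) {
      // models __shfl_xor(v, mask): read the value held by lane (i ^ mask)
      next[i] = v[i] + v[i ^ mask];
    }
    for (int i = 0; i < lanes; ++i) {
      v[i] = next[i];
    }
  }
  for (int i = 0; i < lanes; ++i) {
    printf("lane %d: %d\n", i, v[i]);  // every lane prints 10
  }
  return 0;
}

Because every lane ends up with the result, warpReduceBlock below only needs lane 0 of each warp to spill its value to shared memory.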

template <typename T, typename ReduceOp, int N>
__device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp reduceOp, T init) {
  assert(blockDim.x % warpSize == 0);
  // First, warps cooperate to reduce values within the warp
  warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
  int lane = getLaneId();
  int warp = threadIdx.x / warpSize;

  if (lane == 0) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      smem[warp + (i * warpSize)] = threadVals[i];
    }
  }
  __syncthreads();

#pragma unroll
  for (int i = 0; i < N; ++i) {
    threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init;
  }

  if (warp == 0) {
    warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
  }
}
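A hypothetical usage sketch (the kernel and functor names are mine, not from the commit): a block-wide sum of one float per thread, launched with a block size that is a multiple of warpSize and with dynamic shared memory sized on the host via reduceSmemSize<float, 1>(state, numVals), which on arch >= 3.0 comes to warpSize * sizeof(float) bytes:

// Hypothetical example, not part of this commit.
struct AddOp {
  __device__ float operator()(float a, float b) const { return a + b; }
};

__global__ void blockSumKernel(const float *in, float *out) {
  extern __shared__ float smem[];  // sized via reduceSmemSize<float, 1>() on the host
  float vals[1] = { in[blockIdx.x * blockDim.x + threadIdx.x] };
  warpReduceBlock<float, AddOp, 1>(smem, vals, blockDim.x, AddOp(), 0.0f);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = vals[0];  // all lanes of warp 0 hold the reduced value
  }
}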

// Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
@@ -47,6 +141,9 @@ __device__ void reduceNValuesInBlock(T *smem,
    return;
  }

#if __CUDA_ARCH__ >= 300
  warpReduceBlock<T, ReduceOp, N>(smem, threadVals, numVals, reduceOp, init);
#else
  // We store each of the N values contiguously, so if N = 2, all values for
  // the first threadVal for each thread in the block are stored followed by
  // all of the values for the second threadVal for each thread in the block
@@ -102,6 +199,7 @@ __device__ void reduceNValuesInBlock(T *smem,
}
}
}
#endif
}

// Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
18 changes: 18 additions & 0 deletions lib/THC/THCTensorMode.cuh
@@ -56,13 +56,31 @@ struct ModeUnsignedBoolPair {
  bool flag;
};

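// Lets the warp-shuffle reduction path handle (value, flag) pairs: each member
// is shuffled across lanes independently.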
template <>
struct THCWarpUtils<ModeUnsignedBoolPair> {
  static __device__ __forceinline__ ModeUnsignedBoolPair shflxor(ModeUnsignedBoolPair val, unsigned int mask) {
    val.val = __shfl_xor(val.val, mask);
    val.flag = (bool) __shfl_xor((int) val.flag, mask);
    return val;
  }
};

// In the kernel below, we have a common pattern of reducing (unsigned int, unsigned int)
// pairs of data
struct ModeUnsignedPair {
  unsigned int val;
  unsigned int index;
};

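// Same idea for (value, index) pairs: both members are shuffled with the same
// mask, so the pair moves between lanes intact.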
template <>
struct THCWarpUtils<ModeUnsignedPair> {
  static __device__ __forceinline__ ModeUnsignedPair shflxor(ModeUnsignedPair val, unsigned int mask) {
    val.val = __shfl_xor(val.val, mask);
    val.index = __shfl_xor(val.index, mask);
    return val;
  }
};

template <typename T>
struct MaxReduceOp {
  __host__ __device__ inline T operator()(const T& a, const T& b) {
