diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh
index 0e68f98c..6143084f 100644
--- a/lib/THC/THCReduceApplyUtils.cuh
+++ b/lib/THC/THCReduceApplyUtils.cuh
@@ -8,6 +8,7 @@
 #include "THCTensor.h"
 #include "THCDeviceUtils.cuh"
 #include "THCTensorInfo.cuh"
+#include "THCAsmUtils.cuh"
 
 // Enum that indicates whether tensor arguments are read/write or
 // read-only
@@ -26,8 +27,101 @@ __device__ __forceinline__ IndexType getLinearBlockId() {
 // level) with N elements per thread in the block, so we have to use min(numvals,
 // max block size) to determine this count.
 template <typename T, int N>
-int reduceSmemSize(THCState *state, int numVals) {
-  return THCRoundUp(std::min(numVals, 1024), 32) * N * sizeof(T);
+int reduceSmemSize(THCState *state, long numVals) {
+  // check if we can use a warp shuffle
+  cudaDeviceProp *props = THCState_getCurrentDeviceProperties(state);
+  if (props->major >= 3) {
+    return props->warpSize * N * sizeof(T);
+  } else {
+    return THCRoundUp(std::min(numVals, (long) props->maxThreadsPerBlock), (long) props->warpSize) * N * sizeof(T);
+  }
+}
+
+template <typename T>
+struct THCWarpUtils {
+  static __device__ __forceinline__ T shflxor(T val, unsigned int mask) {
+    return __shfl_xor(val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<unsigned char> {
+  static __device__ __forceinline__ unsigned char shflxor(unsigned char val, unsigned int mask) {
+    return (unsigned char) __shfl_xor((int) val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<char> {
+  static __device__ __forceinline__ char shflxor(char val, unsigned int mask) {
+    return (char) __shfl_xor((int) val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<short> {
+  static __device__ __forceinline__ short shflxor(short val, unsigned int mask) {
+    return (short) __shfl_xor((int) val, mask);
+  }
+};
+
+template <>
+struct THCWarpUtils<double> {
+  static __device__ __forceinline__ double shflxor(double val, unsigned int mask) {
+    int2 a = *reinterpret_cast<int2*>(&val);
+    a.x = __shfl_xor(a.x, mask);
+    a.y = __shfl_xor(a.y, mask);
+    return *reinterpret_cast<double*>(&a);
+  }
+};
+
+template <>
+struct THCWarpUtils<long> {
+  static __device__ __forceinline__ long shflxor(long val, unsigned int mask) {
+    int2 a = *reinterpret_cast<int2*>(&val);
+    a.x = __shfl_xor(a.x, mask);
+    a.y = __shfl_xor(a.y, mask);
+    return *reinterpret_cast<long*>(&a);
+  }
+};
+
+template <typename T, typename ReduceOp, int N>
+__device__ void warpReduce(T threadVals[N], ReduceOp reduceOp) {
+#pragma unroll
+  for (int mask = 1; mask < warpSize; mask *= 2) {
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      T neighbor = THCWarpUtils<T>::shflxor(threadVals[i], mask);
+      threadVals[i] = reduceOp(threadVals[i], neighbor);
+    }
+  }
+}
+
+template <typename T, typename ReduceOp, int N>
+__device__ void warpReduceBlock(T *smem, T threadVals[N], int numVals, ReduceOp reduceOp, T init) {
+  assert(blockDim.x % warpSize == 0);
+  // First, warps cooperate to reduce values within the warp
+  warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
+  int lane = getLaneId();
+  int warp = threadIdx.x / warpSize;
+
+  if (lane == 0) {
+
+#pragma unroll
+    for (int i = 0; i < N; ++i) {
+      smem[warp + (i * warpSize)] = threadVals[i];
+    }
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (int i = 0; i < N; ++i) {
+    threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init;
+  }
+
+  if (warp == 0) {
+    warpReduce<T, ReduceOp, N>(threadVals, reduceOp);
+  }
 }
 
 // Reduce N values concurrently, i.e. suppose N = 2, and there are 4 threads:
@@ -47,6 +141,9 @@ __device__ void reduceNValuesInBlock(T *smem,
     return;
   }
 
+#if __CUDA_ARCH__ >= 300
+  warpReduceBlock<T, ReduceOp, N>(smem, threadVals, numVals, reduceOp, init);
+#else
   // We store each of the N values contiguously, so if N = 2, all values for
   // the first threadVal for each thread in the block are stored followed by
   // all of the values for the second threadVal for each thread in the block
@@ -102,6 +199,7 @@ __device__ void reduceNValuesInBlock(T *smem,
       }
     }
   }
+#endif
 }
 
 // Block-wide reduction in shared memory helper; only threadIdx.x == 0 will
diff --git a/lib/THC/THCTensorMode.cuh b/lib/THC/THCTensorMode.cuh
index d63ed7a4..e2c30cf1 100644
--- a/lib/THC/THCTensorMode.cuh
+++ b/lib/THC/THCTensorMode.cuh
@@ -56,6 +56,15 @@ struct ModeUnsignedBoolPair {
   bool flag;
 };
 
+template <>
+struct THCWarpUtils<ModeUnsignedBoolPair> {
+  static __device__ __forceinline__ ModeUnsignedBoolPair shflxor(ModeUnsignedBoolPair val, unsigned int mask) {
+    val.val = __shfl_xor(val.val, mask);
+    val.flag = (bool) __shfl_xor((int) val.flag, mask);
+    return val;
+  }
+};
+
 // In the kernel below, we have a common pattern of reducing (unsigned int, unsigned int)
 // pairs of data
 struct ModeUnsignedPair {
@@ -63,6 +72,15 @@ struct ModeUnsignedPair {
   unsigned int index;
 };
 
+template <>
+struct THCWarpUtils<ModeUnsignedPair> {
+  static __device__ __forceinline__ ModeUnsignedPair shflxor(ModeUnsignedPair val, unsigned int mask) {
+    val.val = __shfl_xor(val.val, mask);
+    val.index = __shfl_xor(val.index, mask);
+    return val;
+  }
+};
+
 template <typename T>
 struct MaxReduceOp {
   __host__ __device__ inline T operator()(const T& a, const T& b) {
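
For reference, a minimal standalone sketch of the butterfly (__shfl_xor) reduction pattern that warpReduce above relies on, assuming an sm_30+ device and the pre-CUDA-9 shuffle intrinsics this patch targets; the kernel name sumWarpDemo and the host driver are illustrative assumptions, not part of the patch. Each XOR step halves the number of distinct partial sums, so after log2(warpSize) = 5 steps every lane in the warp holds the full warp-wide result.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel: every lane contributes its lane id; after the butterfly
// exchange all 32 lanes hold the warp-wide sum 0 + 1 + ... + 31 = 496.
__global__ void sumWarpDemo(int *out) {
  int val = threadIdx.x;
  // Same loop shape as warpReduce in this patch: partners at XOR distance
  // 1, 2, 4, 8, 16 exchange values and combine them with the reduce op (+ here).
  for (int mask = 1; mask < warpSize; mask *= 2) {
    val += __shfl_xor(val, mask);
  }
  out[threadIdx.x] = val;
}

int main() {
  int *d_out = NULL;
  cudaMalloc(&d_out, 32 * sizeof(int));
  sumWarpDemo<<<1, 32>>>(d_out);
  int h_out[32];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("lane 0 sum = %d (expected 496)\n", h_out[0]);
  cudaFree(d_out);
  return 0;
}

Because every lane already holds the reduced value after the shuffle stage, warpReduceBlock only needs one shared-memory slot per warp per value rather than one per thread, which is why reduceSmemSize can return props->warpSize * N * sizeof(T) on devices with compute capability 3.0 or higher.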