Add support for Warp Shuffle-based reductions for arch >= 3.0 #750
base: master
Conversation
lib/THC/THCReduceApplyUtils.cuh
Outdated
    #pragma unroll
    for (int i = 0; i < N; ++i) {
      threadVals[i] = (threadIdx.x < (blockDim.x / warpSize)) ? smem[lane + (i * warpSize)] : init;
Wouldn't it be better to only have the first warp execute this?
lib/THC/THCReduceApplyUtils.cuh
Outdated
    __syncthreads();

    #pragma unroll
    for (int i = 0; i < N; ++i) {
Yeah, move the warp == 0 check to wrap this too; no need to have the other warps doing some of this work.
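A minimal sketch of what the two comments above suggest, assuming threadVals, smem, lane, N, and init have the meanings they have in the surrounding kernel (names are taken from the quoted hunks, not from the final patch): wrap both the reload of the per-warp partials and the loop that follows in a first-warp-only guard, so the remaining warps skip this work entirely.

// Sketch only: guard the reload of per-warp partials (and the final
// reduction loop that follows) so that only warp 0 executes them.
if (threadIdx.x < warpSize) {
  #pragma unroll
  for (int i = 0; i < N; ++i) {
    // Only the first blockDim.x / warpSize lanes have a per-warp partial
    // staged in shared memory; the remaining lanes read the identity value.
    threadVals[i] = (threadIdx.x < (blockDim.x / warpSize))
        ? smem[lane + (i * warpSize)]
        : init;
  }
  // ... per-thread and cross-lane reduction of threadVals within warp 0 ...
}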
@@ -125,7 +237,7 @@ __device__ T reduceBlockWithNThreadLocalReductions(T *smem,
     local = reduceOp(local, next);
   }

-  return reduceBlock<T, ReduceOp>(smem, blockDim.x < numVals ? blockDim.x : numVals, local, reduceOp, init);
+  return reduceBlock<T, ReduceOp>(smem, THCCeilDiv(numVals, N), local, reduceOp, init);
I don't understand the old or new code here.
Isn't numVals always less than blockDim.x, because it is the number of threads with N active values? In other words, it has nothing to do with N, because all numVals threads have N values?
For the new code: numVals is the overall slice size for the reduction. So if numVals is 512 and N = 2, then the first 256 threads in the block have valid input values. For reduceBlock, numVals represents the number of threads whose values should be considered valid. So in the above example, the first 256 threads have values that should be reduced, hence the division.
The old code was doing something incorrect. However, because the local reduction uses init, it still succeeded.
but then what about the case where numVals is not a multiple of N?
In this case the THCCeilDiv should round it up - if we had 513 values in the above case, then the first 257 threads should have valid input values.
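For concreteness, a sketch of the ceiling division being discussed, written as a local ceilDiv helper so as not to claim it is the exact THCCeilDiv definition, together with the worked numbers from this thread:

// Ceiling division: smallest integer >= a / b.
template <typename T>
__host__ __device__ __forceinline__ T ceilDiv(T a, T b) {
  return (a + b - 1) / b;
}

// With N = 2 values per thread:
//   ceilDiv(512, 2) == 256  -> the first 256 threads hold valid inputs
//   ceilDiv(513, 2) == 257  -> the 257th thread holds one valid input and one init
// The old expression, blockDim.x < numVals ? blockDim.x : numVals
// (i.e. min(blockDim.x, numVals)), counts values rather than per-thread
// partials, so it can report more "valid" threads than actually received
// inputs; those extra threads carry init, which is why the old code still
// produced correct results.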
But then only one of the two input values for the 257th thread would be valid? Where's the guarantee that all provided input values are either valid reduction values or identity values? If they're identity values for the reduction, then we don't really need numVals at all, except possibly to handle tiny reduction cases (smaller than the block size) for slightly greater efficiency.
It just means it's part of the input that we need to consider. The resulting value from the 257th thread is the combination of the input value and an identity value, but the presence of the input value means it must take part in the reduction.
Yes, in theory we could imagine a case where block size > numVals; this is common in the code for mode, for example, where we round up.
…ion, in preparation for warp shuffle
As title. Changes made:
- reduceSmemSize that determines the amount of shared memory required for a reduction. In older architectures, this has not changed. In new architectures, the warp shuffle allows us to use only warpSize * sizeof(T) shared memory for each reduction
- shfl_xor
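To make the shfl_xor item concrete, here is a minimal sketch of a butterfly warp reduction on arch >= 3.0. The names warpReduce and ReduceOp are illustrative, not the PR's actual code; the intrinsic is the pre-CUDA-9 __shfl_xor that was current when this PR was written, and T is assumed to be a type the intrinsic supports (e.g. int or float).

// Butterfly reduction within a single warp using __shfl_xor (sm_30+).
// Each lane starts with its own partial value; after log2(warpSize) steps
// every lane holds the reduction over all 32 lanes, with no shared memory
// traffic beyond the per-warp staging described above.
template <typename T, typename ReduceOp>
__device__ __forceinline__ T warpReduce(T val, ReduceOp reduceOp) {
  for (int mask = warpSize / 2; mask > 0; mask >>= 1) {
    val = reduceOp(val, __shfl_xor(val, mask));
  }
  return val;
}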