Improve GPU performance of multi-patch operator #45

Merged · merged 1 commit · Nov 7, 2024
39 changes: 31 additions & 8 deletions ext/MPIRecoKernelAbstractionsExt/MultiPatch.jl
@@ -42,18 +42,39 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
shared[localIdx] = zero(eltype(b))

# First we iterate over the sparse indices
- @unroll for i = localIdx:grid_stride:N
-   shared[localIdx] = shared[localIdx] + sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
+ # The following code does essentially this:
+ #   tmp = zero(eltype(b))
+ #   @unroll for i = localIdx:grid_stride:N
+ #     tmp += sign * S[patch_row, xss[i, patch], smIdx] * x[xcc[i, patch]]
+ #   end
+ #   shared[localIdx] = tmp
+ # We first sum into a temporary variable, hoping it is kept in a register, since registers are faster than shared memory.

+ # In this variant we additionally use multiple registers for independent partial sums, giving the hardware more instruction-level parallelism.
+ tmp = @private eltype(b) 8
+ @unroll for j = 1:8
+   tmp[j] = zero(eltype(b))
+ end
+ @unroll for i = localIdx:grid_stride*8:N
+   @unroll for j = 1:8
+     index = i + (j - 1) * grid_stride
+     if index <= N
+       tmp[j] = tmp[j] + sign * S[patch_row, xss[index, patch], smIdx] * x[xcc[index, patch]]
+     end
+   end
+ end
+ @unroll for j = 1:8
+   shared[localIdx] += tmp[j]
+ end
@synchronize

# Now we need to reduce the shared memory to get the final result
full_reduction = grid_stride < N
if full_reduction

- # For a full reduction we know s = 1024 and can (manually) unroll our loop
- localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
- @synchronize
+ # For a full reduction we know s = 512 and can (manually) unroll our loop
+ #localIdx <= 512 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 512])
+ #@synchronize
localIdx <= 256 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 256])
@synchronize
localIdx <= 128 && (@inbounds shared[localIdx] = shared[localIdx] + shared[localIdx + 128])
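The register-blocked accumulation above can be illustrated in isolation. Below is a minimal CPU sketch of the same idea, assuming a plain vector `v` in place of the gathered system-matrix products (the name `strided_sum_ilp` is illustrative, not part of the PR): each thread keeps 8 independent partial sums so that consecutive fused multiply-adds do not pile up in one serial dependency chain.

```julia
# Register-blocked, grid-strided sum: `localIdx` and `grid_stride` play the
# roles of the thread index and the stride between a thread's elements.
function strided_sum_ilp(v::AbstractVector{T}, localIdx::Integer, grid_stride::Integer) where {T}
    N = length(v)
    tmp = ntuple(_ -> zero(T), 8)  # 8 independent accumulators ("registers")
    i = localIdx
    while i <= N
        tmp = ntuple(8) do j
            index = i + (j - 1) * grid_stride
            index <= N ? tmp[j] + v[index] : tmp[j]
        end
        i += 8 * grid_stride
    end
    return sum(tmp)  # the kernel instead adds the 8 values into shared memory
end

# The per-thread partial sums still cover every index exactly once:
# sum(strided_sum_ilp(collect(1.0:1000.0), t, 4) for t in 1:4) == 500500.0
```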
@@ -90,8 +111,8 @@ function LinearAlgebra.mul!(b::AbstractVector{T}, op::DenseMultiPatchOperator{T,
end
end

- kernel = dense_mul!(backend, 1024)
- kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (1024, size(op, 1)))
+ kernel = dense_mul!(backend, 512)
+ kernel(b, x, op.S, op.xcc, op.xss, op.sign, Int32(div(op.M, op.nPatches)), op.RowToPatch, op.patchToSMIdx; ndrange = (512, size(op, 1)))
synchronize(backend)
return b
end
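Halving the workgroup size from 1024 to 512 matches the reduction above: with 512 threads the first 512-wide step is no longer needed, which is why it is commented out. As a standalone illustration of the same manually unrolled tree reduction, here is a sketch for a hypothetical workgroup of 8 threads (the kernel name `block_sum!` and the toy size are not part of the PR); `@synchronize` must be reached uniformly by all threads of a workgroup, which is why the halving steps are written out by hand rather than as a data-dependent loop:

```julia
using KernelAbstractions

# One workgroup loads 8 values into shared memory and halves the active
# range at each step until thread 1 holds the total.
@kernel function block_sum!(out, @Const(v))
    localIdx = @index(Local, Linear)
    shared = @localmem eltype(out) 8
    @inbounds shared[localIdx] = v[localIdx]
    @synchronize
    localIdx <= 4 && (@inbounds shared[localIdx] += shared[localIdx + 4])
    @synchronize
    localIdx <= 2 && (@inbounds shared[localIdx] += shared[localIdx + 2])
    @synchronize
    localIdx == 1 && (@inbounds out[1] = shared[1] + shared[2])
end

backend = CPU()  # any KernelAbstractions backend
v = Float32.(1:8)
out = zeros(Float32, 1)
block_sum!(backend, 8)(out, v; ndrange = 8)  # workgroup size 8, mirroring the 512 above
synchronize(backend)
# out[1] == 36.0f0
```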
@@ -210,9 +231,11 @@ end
shared = @localmem eltype(energy) grid_stride
shared[localIdx] = zero(eltype(energy))

+ tmp = zero(eltype(energy))
@unroll for i = localIdx:grid_stride:N
-   shared[localIdx] = shared[localIdx] + abs2(sign * S[patch_row, xss[i, patch], smIdx])
+   tmp += abs2(sign * S[patch_row, xss[i, patch], smIdx])
end
+ shared[localIdx] = tmp
@synchronize

@private s = div(min(grid_stride, N), Int32(2))
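The norm kernel gets the same register treatment here: the running sum of `abs2` terms lives in a local `tmp` and `shared` is written exactly once instead of on every iteration. A scalar model of that change, with a hypothetical `row_energy` helper standing in for the kernel loop:

```julia
# Accumulate the row energy in a local variable (a register on the GPU)
# and hand the result back once, rather than updating shared memory
# element-wise inside the loop.
function row_energy(S::AbstractMatrix{T}, row::Integer, sign::Number) where {T}
    tmp = zero(real(T))
    for col in axes(S, 2)
        tmp += abs2(sign * S[row, col])
    end
    return tmp  # in the kernel: shared[localIdx] = tmp
end

# row_energy(ComplexF32[1+2im 3im; 0 1], 1, -1) == 14.0f0
```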