Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid cartesian iteration where possible. #464

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,53 @@ end
bc′ = Broadcast.preprocess(dest, bc)

# grid-stride kernel
function broadcast_kernel(ctx, dest, bc′, nelem)
i = 0
while i < nelem
i += 1
I = @cartesianidx(dest, i)
@inbounds dest[I] = bc′[I]
function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is
j = 0
while j < nelem
j += 1

i = @linearidx(dest, j)

# cartesian indexing is slow, so avoid it if possible
if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian)
# this performs an integer division, which is expensive. to make it possible
# for the compiler to optimize it away, we put the iterator in the type
# domain so that the indices are available at compile time. note that LLVM
# only seems to replace pow2 divisions (with bitshifts), but other back-ends
# may be smarter and replace arbitrary divisions by bit operations.
#
# also see maleadt/StaticCartesian.jl, which implements this in Julia,
# but does not result in an additional speed-up on tested back-ends.
#
# in addition, we use @inbounds to avoid bounds checks, but we also need to
# inform the compiler about the bounds that we are assuming. this is done
# using the assume intrinsic, and in case of Metal yields a 8x speed-up.
assume(1 <= i <= length(Is))
I = @inbounds Is[i]
end

val = if isa(IndexStyle(bc′), IndexCartesian)
@inbounds bc′[I]
else
@inbounds bc′[i]
end

if isa(IndexStyle(dest), IndexCartesian)
@inbounds dest[I] = val
else
@inbounds dest[i] = val
end
end
return
end
elements = length(dest)
elements_per_thread = typemax(Int)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
Is = CartesianIndices(dest)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread;
threads=config.threads, blocks=config.blocks)

return dest
Expand Down
2 changes: 1 addition & 1 deletion src/host/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

function Base.clamp!(A::AnyGPUArray, low, high)
gpu_call(A, low, high) do ctx, A, low, high
I = @cartesianidx A
I = @linearidx A
A[I] = clamp(A[I], low, high)
return
end
Expand Down