
Commit

Revert "remocing heuristic"
This reverts commit 0c7e26b.
leios committed Sep 16, 2024
1 parent 00c8dd4 commit 52db290
Showing 3 changed files with 60 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/GPUArrays.jl
@@ -16,6 +16,7 @@ using Reexport
@reexport using GPUArraysCore

## executed on-device
include("device/execution.jl")
include("device/abstractarray.jl")

using KernelAbstractions
39 changes: 39 additions & 0 deletions src/device/execution.jl
@@ -0,0 +1,39 @@
# kernel execution

# how many threads and blocks `kernel` needs to be launched with, when passing arguments
# `args`, to fully saturate the GPU. `elements` indicates the number of elements that need
# to be processed, while `elements_per_thread` indicates the number of elements each thread
# can process (i.e. more than 1 if it's a grid-stride kernel, or 1 otherwise).
#
# this heuristic should be specialized for the back-end, ideally using an API for maximizing
# the occupancy of the launch configuration (like CUDA's occupancy API).
function launch_heuristic(backend::B, kernel, args...;
                          elements::Int,
                          elements_per_thread::Int) where B <: Backend
    return (threads=256, blocks=32)
end

# determine how many threads and blocks to actually launch, given the upper limits from the
# heuristic. returns a named tuple of threads, blocks, and elements_per_thread (the latter is
# always 1 unless the caller indicated that the kernel can process multiple elements per thread)
function launch_configuration(backend::B, heuristic;
                              elements::Int,
                              elements_per_thread::Int) where B <: Backend
    threads = clamp(elements, 1, heuristic.threads)
    blocks = max(cld(elements, threads), 1)

    if elements_per_thread > 1 && blocks > heuristic.blocks
        # we would be launching more blocks than needed to saturate the device, so prefer a grid-stride loop instead
        ## try to stick to the number of blocks that the heuristic suggested
        blocks = heuristic.blocks
        nelem = cld(elements, blocks*threads)
        ## only bump the number of blocks if we really need to
        if nelem > elements_per_thread
            nelem = elements_per_thread
            blocks = cld(elements, nelem*threads)
        end
        (; threads, blocks, elements_per_thread=nelem)
    else
        (; threads, blocks, elements_per_thread=1)
    end
end
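The generic fallback above deliberately returns a fixed (threads=256, blocks=32); as the comment notes, a back-end is expected to override launch_heuristic with an occupancy-based answer. A minimal sketch of what such a specialization could look like, assuming a hypothetical MyBackend <: KernelAbstractions.Backend and a hypothetical query_occupancy helper standing in for the vendor occupancy API (neither is part of this commit):

# sketch only: `MyBackend` and `query_occupancy` are hypothetical stand-ins for a real
# back-end type and whatever occupancy query its driver exposes
function GPUArrays.launch_heuristic(backend::MyBackend, kernel, args...;
                                    elements::Int, elements_per_thread::Int)
    # ask the (hypothetical) driver how many threads and blocks saturate the device
    occ = query_occupancy(backend, kernel, args...)   # assumed to return (threads=..., blocks=...)
    return (threads=occ.threads, blocks=occ.blocks)
end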
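To make the arithmetic in launch_configuration concrete, here is what it returns for the fallback heuristic when covering one million elements with a grid-stride kernel; the element count and the use of KernelAbstractions' CPU() back-end (only there for dispatch) are assumptions for illustration:

using GPUArrays, KernelAbstractions

heuristic = (threads=256, blocks=32)   # what the generic launch_heuristic above returns
config = GPUArrays.launch_configuration(CPU(), heuristic;
                                        elements=1_000_000, elements_per_thread=typemax(Int))
# threads = clamp(1_000_000, 1, 256)            -> 256
# blocks  = cld(1_000_000, 256) = 3907 > 32     -> stick with the heuristic's 32 blocks
# nelem   = cld(1_000_000, 32*256)              -> 123
# config == (threads = 256, blocks = 32, elements_per_thread = 123)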
26 changes: 20 additions & 6 deletions src/host/broadcast.jl
@@ -117,14 +117,28 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
end

# grid-stride kernel
@kernel function map_kernel(dest, bc)
    j = @index(Global, Linear)
    @inbounds dest[j] = bc[j]
@kernel function map_kernel(dest, bc, nelem, common_length)

    j = 0
    J = @index(Global, Linear)
    for i in 1:nelem
        j += 1
        if j <= common_length

            J_c = CartesianIndices(axes(bc))[(J-1)*nelem + j]
            @inbounds dest[J_c] = bc[J_c]
        end
    end
end

elements = common_length
elements_per_thread = typemax(Int)
kernel = map_kernel(get_backend(dest))
config = KernelAbstractions.launch_config(kernel, common_length, nothing)
kernel(dest, bc; ndrange = config[1], workgroupsize = config[2])
heuristic = launch_heuristic(get_backend(dest), kernel, dest, bc, 1,
                             common_length; elements, elements_per_thread)
config = launch_configuration(get_backend(dest), heuristic;
                              elements, elements_per_thread)
kernel(dest, bc, config.elements_per_thread,
       common_length; ndrange = config.threads)

if eltype(dest) <: BrokenBroadcast
    throw(ArgumentError("Map operation resulting in $(eltype(eltype(dest))) is not GPU compatible"))
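For context on how the code in this hunk is reached: map! on any GPUArrays-backed array type dispatches to the Base.map! method shown above, which builds a Broadcasted object and, in this version, launches map_kernel through the launch_heuristic/launch_configuration machinery. A usage sketch, assuming the JLArrays.jl reference back-end that lives in this repository (the back-end choice is illustrative; any GPUArrays back-end would do):

using JLArrays   # reference back-end, assumed to be usable alongside this branch

a = JLArray(rand(Float32, 10_000))
b = similar(a)
map!(x -> 2x, b, a)   # dispatches to Base.map!(f, dest::AnyGPUArray, xs...) above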
