Skip to content

Commit

Permalink
Merge pull request #144 from omlins/gpu-ext
Browse files Browse the repository at this point in the history
Use extensions for GPU dependencies
  • Loading branch information
omlins authored Mar 14, 2024
2 parents 62eff06 + 3c44361 commit fd7a162
Show file tree
Hide file tree
Showing 33 changed files with 240 additions and 169 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ authors = ["Samuel Omlin", "Ludovic Räss"]
version = "0.11.1"

[deps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CellArrays = "d35fcfd7-7af4-4c67-b1aa-d78070614af4"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
ParallelStencil_AMDGPUExt = "AMDGPU"
ParallelStencil_CUDAExt = "CUDA"
ParallelStencil_EnzymeExt = "Enzyme"

[compat]
Expand All @@ -31,4 +33,4 @@ TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "TOML", "Enzyme"]
test = ["Test", "TOML", "AMDGPU", "CUDA", "Enzyme"]
4 changes: 4 additions & 0 deletions ext/ParallelStencil_AMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
module ParallelStencil_AMDGPUExt
    # Load the AMDGPU backend sources into this extension module. `shared.jl`
    # must come first: it imports the ParallelKernel names and defines
    # `ROCCellArray` (via @define_ROCCellArray), both used by `allocators.jl`.
    let srcdir = joinpath(@__DIR__, "..", "src", "ParallelKernel", "AMDGPUExt")
        include(joinpath(srcdir, "shared.jl"))
        include(joinpath(srcdir, "allocators.jl"))
    end
end
4 changes: 4 additions & 0 deletions ext/ParallelStencil_CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
module ParallelStencil_CUDAExt
    # Load the CUDA backend sources into this extension module. `shared.jl`
    # must come first: it imports the ParallelKernel names and defines
    # `CuCellArray` (via @define_CuCellArray), both used by `allocators.jl`.
    let srcdir = joinpath(@__DIR__, "..", "src", "ParallelKernel", "CUDAExt")
        include(joinpath(srcdir, "shared.jl"))
        include(joinpath(srcdir, "allocators.jl"))
    end
end
29 changes: 29 additions & 0 deletions src/ParallelKernel/AMDGPUExt/allocators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## RUNTIME ALLOCATOR FUNCTIONS

# AMDGPU implementations of the ParallelKernel runtime allocators. Scalar
# element types (Number/Enum/Bool) yield plain `ROCArray`s; cell element types
# (SArray/FieldArray) yield `ROCCellArray`s with blocklength `B`, where `B == 0`
# selects a single block covering the whole array (see `rand_amdgpu` below).

# Allocators for scalar element types. `rand` and `fill` build the data on the
# CPU and copy it to the device -- presumably so Enum element types are
# supported (cf. the TODO on the cell-type `rand` below); confirm.
# NOTE(review): `falses`/`trues` rely on `AMDGPU.falses`/`AMDGPU.trues`
# existing -- the CUDA counterpart uses `CUDA.zeros(Bool, ...)` instead;
# confirm AMDGPU provides these.
ParallelStencil.ParallelKernel.zeros_amdgpu(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_amdgpu(T); AMDGPU.zeros(T, args...)) # (blocklength is ignored if neither celldims nor celltype is set)
ParallelStencil.ParallelKernel.ones_amdgpu(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_amdgpu(T); AMDGPU.ones(T, args...)) # ...
ParallelStencil.ParallelKernel.rand_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = ROCArray(rand_cpu(T, blocklength, args...)) # ...
ParallelStencil.ParallelKernel.falses_amdgpu(::Type{T}, blocklength, args...) where {T<:Bool} = AMDGPU.falses(args...) # ...
ParallelStencil.ParallelKernel.trues_amdgpu(::Type{T}, blocklength, args...) where {T<:Bool} = AMDGPU.trues(args...) # ...
ParallelStencil.ParallelKernel.fill_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = ROCArray(fill_cpu(T, blocklength, args...)) # ...

# Allocators for cell element types (SArray/FieldArray): zeros/ones/falses/
# trues reduce to `fill_amdgpu` with the corresponding scalar value.
ParallelStencil.ParallelKernel.zeros_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_amdgpu(T); fill_amdgpu(T, blocklength, 0, args...))
ParallelStencil.ParallelKernel.ones_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_amdgpu(T); fill_amdgpu(T, blocklength, 1, args...))
ParallelStencil.ParallelKernel.rand_amdgpu(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_amdgpu(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, AMDGPU.ROCArray{eltype(T),3}}(AMDGPU.ROCArray(Base.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen)))), dims)) # TODO: use AMDGPU.rand! instead of AMDGPU.rand once it supports Enums: rand_amdgpu(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_amdgpu(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, AMDGPU.ROCArray{eltype(T),3}}(AMDGPU.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims))
ParallelStencil.ParallelKernel.rand_amdgpu(::Type{T}, blocklength, dims...) where {T<:Union{SArray,FieldArray}} = rand_amdgpu(T, blocklength, dims)
ParallelStencil.ParallelKernel.falses_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_amdgpu(T, blocklength, false, args...)
ParallelStencil.ParallelKernel.trues_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_amdgpu(T, blocklength, true, args...)

# Allocate a `ROCCellArray{T,B}` of the given dimensions and fill every cell
# with `x`, which is either a scalar (broadcast over the cell) or a container
# with exactly one value per cell element.
function ParallelStencil.ParallelKernel.fill_amdgpu(::Type{T}, ::Val{B}, x, args...) where {T <: Union{SArray,FieldArray}, B}
# Reject an `x` whose element type is neither a plain number nor the cell's
# eltype: automatic conversion would be ambiguous, so the caller must resolve it.
if (!(eltype(x) <: Number) || (eltype(x) == Bool)) && (eltype(x) != eltype(T)) @ArgumentError("fill: the (element) type of argument 'x' is not a normal number type ($(eltype(x))), but does not match the obtained (default) 'eltype' ($(eltype(T))); automatic conversion to $(eltype(T)) is therefore not attempted. Set the keyword argument 'eltype' accordingly to the element type of 'x' or pass an 'x' of a different (element) type.") end
check_datatype_amdgpu(T, Bool, Enum)
# Build the cell value: scalar -> broadcast to the cell shape; container of
# matching length -> converted element-wise; anything else is an error.
if (length(x) == 1) cell = convert(T, fill(convert(eltype(T), x), size(T)))
elseif (length(x) == length(T)) cell = convert(T, x)
else @ArgumentError("fill: argument 'x' contains the wrong number of elements ($(length(x))). It must be a scalar or contain the number of elements defined by 'celldims'.")
end
return CellArrays.fill!(ROCCellArray{T,B}(undef, args...), cell)
end

# In-place fill of an existing device array `A` with `x` (converted to a cell
# via `construct_cell`).
ParallelStencil.ParallelKernel.fill_amdgpu!(A, x) = AMDGPU.fill!(A, construct_cell(A, x))

# Datatype validation with the AMDGPU-specific index type appended.
check_datatype_amdgpu(args...) = check_datatype(args..., INT_AMDGPU)
18 changes: 18 additions & 0 deletions src/ParallelKernel/AMDGPUExt/defaults.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
const ERRMSG_AMDGPUEXT_NOT_LOADED = "the AMDGPU extension was not loaded. Make sure to import AMDGPU before ParallelStencil."


# shared.jl

# Generic function stubs; their methods are attached by the AMDGPU extension.
function get_priority_rocstream end
function get_rocstream end


# allocators.jl

# Allocator fallbacks: until the AMDGPU extension attaches the real methods,
# any call signals via @NotLoadedError that AMDGPU must be imported first.
for f in (:zeros_amdgpu, :ones_amdgpu, :rand_amdgpu, :falses_amdgpu,
          :trues_amdgpu, :fill_amdgpu, :fill_amdgpu!)
    @eval $f(args...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED)
end
31 changes: 31 additions & 0 deletions src/ParallelKernel/AMDGPUExt/shared.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import ParallelStencil
import ParallelStencil.ParallelKernel: INT_AMDGPU, rand_cpu, fill_cpu, construct_cell, check_datatype, rand_amdgpu, fill_amdgpu
using ParallelStencil.ParallelKernel.Exceptions
using AMDGPU, CellArrays, StaticArrays
@define_ROCCellArray


## FUNCTIONS TO CHECK EXTENSIONS SUPPORT

# Let ParallelKernel query whether this extension is active.
ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_AMDGPUExt}) = true


## FUNCTIONS TO CREATE AND MANAGE AMDGPU QUEUES AND "ROCSTREAMS"

# Forward the ParallelKernel stream accessors to the pool-backed closures below.
ParallelStencil.ParallelKernel.get_priority_rocstream(arg...) = get_priority_rocstream(arg...)
ParallelStencil.ParallelKernel.get_rocstream(arg...) = get_rocstream(arg...)

# The stream pools live in a `let` block so they stay private to the two
# accessors; stream number `id` is created lazily on first request and cached.
let priority_pool = AMDGPU.HIPStream[], pool = AMDGPU.HIPStream[]
    global get_priority_rocstream, get_rocstream

    # Return the id-th high-priority HIP stream, growing the pool as needed.
    function get_priority_rocstream(id::Integer)
        while length(priority_pool) < id
            push!(priority_pool, AMDGPU.HIPStream(:high))
        end
        return priority_pool[id]
    end

    # Return the id-th low-priority HIP stream, growing the pool as needed.
    function get_rocstream(id::Integer)
        while length(pool) < id
            push!(pool, AMDGPU.HIPStream(:low))
        end
        return pool[id]
    end
end
29 changes: 29 additions & 0 deletions src/ParallelKernel/CUDAExt/allocators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## RUNTIME ALLOCATOR FUNCTIONS

# CUDA implementations of the ParallelKernel runtime allocators. Scalar
# element types (Number/Enum/Bool) yield plain `CuArray`s; cell element types
# (SArray/FieldArray) yield `CuCellArray`s with blocklength `B`, where `B == 0`
# selects a single block covering the whole array (see `rand_cuda` below).

# Allocators for scalar element types. `rand` and `fill` build the data on the
# CPU and copy it to the device -- presumably so Enum element types are
# supported; confirm. Bool arrays are allocated as zeros/ones of Bool.
ParallelStencil.ParallelKernel.zeros_cuda(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_cuda(T); CUDA.zeros(T, args...)) # (blocklength is ignored if neither celldims nor celltype is set)
ParallelStencil.ParallelKernel.ones_cuda(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_cuda(T); CUDA.ones(T, args...)) # ...
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = CuArray(rand_cpu(T, blocklength, args...)) # ...
ParallelStencil.ParallelKernel.falses_cuda(::Type{T}, blocklength, args...) where {T<:Bool} = CUDA.zeros(Bool, args...) # ...
ParallelStencil.ParallelKernel.trues_cuda(::Type{T}, blocklength, args...) where {T<:Bool} = CUDA.ones(Bool, args...) # ...
ParallelStencil.ParallelKernel.fill_cuda(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = CuArray(fill_cpu(T, blocklength, args...)) # ...

# Allocators for cell element types (SArray/FieldArray): zeros/ones/falses/
# trues reduce to `fill_cuda` with the corresponding scalar value; `rand`
# generates directly on the device with CUDA.rand.
ParallelStencil.ParallelKernel.zeros_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_cuda(T); fill_cuda(T, blocklength, 0, args...))
ParallelStencil.ParallelKernel.ones_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_cuda(T); fill_cuda(T, blocklength, 1, args...))
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_cuda(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, CUDA.CuArray{eltype(T),3}}(CUDA.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims))
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, blocklength, dims...) where {T<:Union{SArray,FieldArray}} = rand_cuda(T, blocklength, dims)
ParallelStencil.ParallelKernel.falses_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_cuda(T, blocklength, false, args...)
ParallelStencil.ParallelKernel.trues_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_cuda(T, blocklength, true, args...)

# Allocate a `CuCellArray{T,B}` of the given dimensions and fill every cell
# with `x`, which is either a scalar (broadcast over the cell) or a container
# with exactly one value per cell element.
function ParallelStencil.ParallelKernel.fill_cuda(::Type{T}, ::Val{B}, x, args...) where {T <: Union{SArray,FieldArray}, B}
# Reject an `x` whose element type is neither a plain number nor the cell's
# eltype: automatic conversion would be ambiguous, so the caller must resolve it.
if (!(eltype(x) <: Number) || (eltype(x) == Bool)) && (eltype(x) != eltype(T)) @ArgumentError("fill: the (element) type of argument 'x' is not a normal number type ($(eltype(x))), but does not match the obtained (default) 'eltype' ($(eltype(T))); automatic conversion to $(eltype(T)) is therefore not attempted. Set the keyword argument 'eltype' accordingly to the element type of 'x' or pass an 'x' of a different (element) type.") end
check_datatype_cuda(T, Bool, Enum)
# Build the cell value: scalar -> broadcast to the cell shape; container of
# matching length -> converted element-wise; anything else is an error.
if (length(x) == 1) cell = convert(T, fill(convert(eltype(T), x), size(T)))
elseif (length(x) == length(T)) cell = convert(T, x)
else @ArgumentError("fill: argument 'x' contains the wrong number of elements ($(length(x))). It must be a scalar or contain the number of elements defined by 'celldims'.")
end
return CellArrays.fill!(CuCellArray{T,B}(undef, args...), cell)
end

# In-place fill of an existing device array `A` with `x` (converted to a cell
# via `construct_cell`).
ParallelStencil.ParallelKernel.fill_cuda!(A, x) = CUDA.fill!(A, construct_cell(A, x))

# Datatype validation with the CUDA-specific index type appended.
check_datatype_cuda(args...) = check_datatype(args..., INT_CUDA)
18 changes: 18 additions & 0 deletions src/ParallelKernel/CUDAExt/defaults.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
const ERRMSG_CUDAEXT_NOT_LOADED = "the CUDA extension was not loaded. Make sure to import CUDA before ParallelStencil."


# shared.jl

# Generic function stubs; their methods are attached by the CUDA extension.
function get_priority_custream end
function get_custream end


# allocators.jl

# Allocator fallbacks: until the CUDA extension attaches the real methods,
# any call signals via @NotLoadedError that CUDA must be imported first.
for f in (:zeros_cuda, :ones_cuda, :rand_cuda, :falses_cuda,
          :trues_cuda, :fill_cuda, :fill_cuda!)
    @eval $f(args...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED)
end
31 changes: 31 additions & 0 deletions src/ParallelKernel/CUDAExt/shared.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import ParallelStencil
import ParallelStencil.ParallelKernel: INT_CUDA, rand_cpu, fill_cpu, construct_cell, check_datatype, rand_cuda, fill_cuda
using ParallelStencil.ParallelKernel.Exceptions
using CUDA, CellArrays, StaticArrays
@define_CuCellArray


## FUNCTIONS TO CHECK EXTENSIONS SUPPORT

# Let ParallelKernel query whether this extension is active.
ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_CUDAExt}) = true


## FUNCTIONS TO CREATE AND MANAGE CUDA STREAMS

# Forward the ParallelKernel stream accessors to the pool-backed closures below.
ParallelStencil.ParallelKernel.get_priority_custream(arg...) = get_priority_custream(arg...)
ParallelStencil.ParallelKernel.get_custream(arg...) = get_custream(arg...)

# The stream pools live in a `let` block so they stay private to the two
# accessors; stream number `id` is created lazily on first request and cached.
# NOTE: CUDA.priority_range() cannot be called at load time, as only at runtime
# it is sure that CUDA is functional.
let priority_pool = CuStream[], pool = CuStream[]
    global get_priority_custream, get_custream

    # Return the id-th non-blocking stream at maximum priority
    # (CUDA.priority_range()[end]), growing the pool as needed.
    function get_priority_custream(id::Integer)
        while length(priority_pool) < id
            push!(priority_pool, CuStream(; flags=CUDA.STREAM_NON_BLOCKING, priority=CUDA.priority_range()[end]))
        end
        return priority_pool[id]
    end

    # Return the id-th non-blocking stream at minimum priority
    # (CUDA.priority_range()[1]), growing the pool as needed.
    function get_custream(id::Integer)
        while length(pool) < id
            push!(pool, CuStream(; flags=CUDA.STREAM_NON_BLOCKING, priority=CUDA.priority_range()[1]))
        end
        return pool[id]
    end
end
8 changes: 4 additions & 4 deletions src/ParallelKernel/Data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ Expands to: `NTuple{N_tuple, Data.Cell{numbertype, S}}` | `NamedTuple{names, NTu
function Data_cuda(numbertype::DataType, indextype::DataType)
if numbertype == NUMBERTYPE_NONE
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, ParallelStencil.ParallelKernel.CUDA, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
import Base, CUDA, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_CuCellArray
export CuCellArray
const Index = $indextype
Expand All @@ -155,7 +155,7 @@ function Data_cuda(numbertype::DataType, indextype::DataType)
end)
else
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, ParallelStencil.ParallelKernel.CUDA, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
import Base, CUDA, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_CuCellArray
export CuCellArray
const Index = $indextype
Expand All @@ -179,7 +179,7 @@ end
function Data_amdgpu(numbertype::DataType, indextype::DataType)
if numbertype == NUMBERTYPE_NONE
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, ParallelStencil.ParallelKernel.AMDGPU, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
import Base, AMDGPU, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_ROCCellArray
export ROCCellArray
const Index = $indextype
Expand All @@ -192,7 +192,7 @@ function Data_amdgpu(numbertype::DataType, indextype::DataType)
end)
else
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, ParallelStencil.ParallelKernel.AMDGPU, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
import Base, AMDGPU, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_ROCCellArray
export ROCCellArray
const Index = $indextype
Expand Down
6 changes: 3 additions & 3 deletions src/ParallelKernel/EnzymeExt/AD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ To see a description of a function type `?<functionname>`.
module AD
using ..Exceptions

const ERRMSG_EXTENSION_LOAD_ERROR = "AD: the Enzyme extension was not loaded. Make sure to import Enzyme before ParallelStencil."
const ERRMSG_ENZYMEEXT_NOT_LOADED = "AD: the Enzyme extension was not loaded. Make sure to import Enzyme before ParallelStencil."

init_AD(args...) = return # NOTE: a call will be triggered from @init_parallel_kernel, but it will do nothing if the extension is not loaded. Methods are to be defined in the AD extension modules.
autodiff_deferred!(args...) = @NotLoadedError(ERRMSG_EXTENSION_NOT_LOADED)
autodiff_deferred_thunk!(args...) = @NotLoadedError(ERRMSG_EXTENSION_NOT_LOADED)
autodiff_deferred!(args...) = @NotLoadedError(ERRMSG_ENZYMEEXT_NOT_LOADED)
autodiff_deferred_thunk!(args...) = @NotLoadedError(ERRMSG_ENZYMEEXT_NOT_LOADED)

export autodiff_deferred!, autodiff_deferred_thunk!

Expand Down
5 changes: 5 additions & 0 deletions src/ParallelKernel/EnzymeExt/autodiff_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ function ParallelStencil.ParallelKernel.AD.autodiff_deferred_thunk!(arg, args...
Enzyme.autodiff_deferred_thunk(arg, args...)
return
end


## FUNCTIONS TO CHECK EXTENSIONS SUPPORT

ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_EnzymeExt}) = true
4 changes: 4 additions & 0 deletions src/ParallelKernel/ParallelKernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ using .Exceptions
include(joinpath("EnzymeExt", "AD.jl"));
include("Data.jl");

## Alphabetical include of defaults for extensions
include(joinpath("AMDGPUExt", "defaults.jl"))
include(joinpath("CUDAExt", "defaults.jl"))

## Include of constant parameters, types and syntax sugar shared in ParallelKernel module only
include("shared.jl")

Expand Down
Loading

0 comments on commit fd7a162

Please sign in to comment.