Skip to content

Commit

Permalink
Merge pull request #147 from omlins/ext-fixes
Browse files Browse the repository at this point in the history
Minor improvements for GPU extensions
  • Loading branch information
omlins authored Mar 22, 2024
2 parents d2008e0 + 1ddbd35 commit d20832d
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 48 deletions.
89 changes: 54 additions & 35 deletions src/ParallelKernel/Data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,25 +139,30 @@ Expands to: `NTuple{N_tuple, Data.Cell{numbertype, S}}` | `NamedTuple{names, NTu
This datatype is not intended for explicit manual usage. [`@parallel`](@ref) and [`@parallel_indices`](@ref) convert CUDA.CuArray and AMDGPU.ROCArray automatically to CUDA.CuDeviceArray and AMDGPU.ROCDeviceArray in kernels when required.
"""

function Data_cuda(numbertype::DataType, indextype::DataType)
if numbertype == NUMBERTYPE_NONE
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
function Data_cuda(modulename::Symbol, numbertype::DataType, indextype::DataType)
Data_module = if (numbertype == NUMBERTYPE_NONE)
:(baremodule $modulename # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, CUDA, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_CuCellArray
export CuCellArray
# TODO: the constructors defined by CellArrays.@define_CuCellArray lead to pre-compilation issues due to a bug in Julia. We therefore only create the type alias here for now.
const CuCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,CUDA.CuArray{T_elem,CellArrays._N}}
# CellArrays.@define_CuCellArray
# export CuCellArray
const Index = $indextype
const Array{T, N} = CUDA.CuArray{T, N}
const DeviceArray{T, N} = CUDA.CuDeviceArray{T, N}
const Cell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const DeviceCell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const CellArray{T_elem, N, B} = CuCellArray{<:Cell{T_elem},N,B,T_elem}
const DeviceCellArray{T_elem, N, B} = CellArrays.CellArray{<:DeviceCell{T_elem},N,B,<:CUDA.CuDeviceArray{T_elem,CellArrays._N}}
$(create_shared_exprs(numbertype, indextype))
end)
else
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
:(baremodule $modulename # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, CUDA, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_CuCellArray
export CuCellArray
# TODO: the constructors defined by CellArrays.@define_CuCellArray lead to pre-compilation issues due to a bug in Julia. We therefore only create the type alias here for now.
const CuCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,CUDA.CuArray{T_elem,CellArrays._N}}
# CellArrays.@define_CuCellArray
# export CuCellArray
const Index = $indextype
const Number = $numbertype
const Array{N} = CUDA.CuArray{$numbertype, N}
Expand All @@ -172,29 +177,36 @@ function Data_cuda(numbertype::DataType, indextype::DataType)
const DeviceTCell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const TCellArray{T_elem, N, B} = CuCellArray{<:TCell{T_elem},N,B,T_elem}
const DeviceTCellArray{T_elem, N, B} = CellArrays.CellArray{<:DeviceTCell{T_elem},N,B,<:CUDA.CuDeviceArray{T_elem,CellArrays._N}}
$(create_shared_exprs(numbertype, indextype))
end)
end
return prewalk(rmlines, flatten(Data_module))
end

function Data_amdgpu(numbertype::DataType, indextype::DataType)
if numbertype == NUMBERTYPE_NONE
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
function Data_amdgpu(modulename::Symbol, numbertype::DataType, indextype::DataType)
Data_module = if (numbertype == NUMBERTYPE_NONE)
:(baremodule $modulename # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, AMDGPU, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_ROCCellArray
export ROCCellArray
# TODO: the constructors defined by CellArrays.@define_ROCCellArray lead to pre-compilation issues due to a bug in Julia. We therefore only create the type alias here for now.
const ROCCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,AMDGPU.ROCArray{T_elem,CellArrays._N}}
# CellArrays.@define_ROCCellArray
# export ROCCellArray
const Index = $indextype
const Array{T, N} = AMDGPU.ROCArray{T, N}
const DeviceArray{T, N} = AMDGPU.ROCDeviceArray{T, N}
const Cell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const DeviceCell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const CellArray{T_elem, N, B} = ROCCellArray{<:Cell{T_elem},N,B,T_elem}
const DeviceCellArray{T_elem, N, B} = CellArrays.CellArray{<:DeviceCell{T_elem},N,B,<:AMDGPU.ROCDeviceArray{T_elem,CellArrays._N}}
$(create_shared_exprs(numbertype, indextype))
end)
else
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
:(baremodule $modulename # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, AMDGPU, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
CellArrays.@define_ROCCellArray
export ROCCellArray
# TODO: the constructors defined by CellArrays.@define_ROCCellArray lead to pre-compilation issues due to a bug in Julia. We therefore only create the type alias here for now.
const ROCCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,AMDGPU.ROCArray{T_elem,CellArrays._N}}
# CellArrays.@define_ROCCellArray
# export ROCCellArray
const Index = $indextype
const Number = $numbertype
const Array{N} = AMDGPU.ROCArray{$numbertype, N}
Expand All @@ -209,13 +221,15 @@ function Data_amdgpu(numbertype::DataType, indextype::DataType)
const DeviceTCell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const TCellArray{T_elem, N, B} = ROCCellArray{<:TCell{T_elem},N,B,T_elem}
const DeviceTCellArray{T_elem, N, B} = CellArrays.CellArray{<:DeviceTCell{T_elem},N,B,<:AMDGPU.ROCDeviceArray{T_elem,CellArrays._N}}
$(create_shared_exprs(numbertype, indextype))
end)
end
return prewalk(rmlines, flatten(Data_module))
end

function Data_threads(numbertype::DataType, indextype::DataType)
if numbertype == NUMBERTYPE_NONE
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
function Data_threads(modulename::Symbol, numbertype::DataType, indextype::DataType)
Data_module = if (numbertype == NUMBERTYPE_NONE)
:(baremodule $modulename # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
const Index = $indextype
const Array{T, N} = Base.Array{T, N}
Expand All @@ -224,9 +238,10 @@ function Data_threads(numbertype::DataType, indextype::DataType)
const DeviceCell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const CellArray{T_elem, N, B} = CellArrays.CPUCellArray{<:Cell{T_elem},N,B,T_elem}
const DeviceCellArray{T_elem, N, B} = CellArrays.CPUCellArray{<:DeviceCell{T_elem},N,B,T_elem}
$(create_shared_exprs(numbertype, indextype))
end)
else
:(baremodule Data # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
:(baremodule $modulename # NOTE: there cannot be any newline before 'module Data' or it will create a begin end block and the module creation will fail.
import Base, ParallelStencil.ParallelKernel.CellArrays, ParallelStencil.ParallelKernel.StaticArrays
const Index = $indextype
const Number = $numbertype
Expand All @@ -242,11 +257,13 @@ function Data_threads(numbertype::DataType, indextype::DataType)
const DeviceTCell{T, S} = Union{StaticArrays.SArray{S, T}, StaticArrays.FieldArray{S, T}}
const TCellArray{T_elem, N, B} = CellArrays.CPUCellArray{<:TCell{T_elem},N,B,T_elem}
const DeviceTCellArray{T_elem, N, B} = CellArrays.CPUCellArray{<:DeviceTCell{T_elem},N,B,T_elem}
$(create_shared_exprs(numbertype, indextype))
end)
end
return prewalk(rmlines, flatten(Data_module))
end

function Data_shared(numbertype::DataType, indextype::DataType)
function create_shared_exprs(numbertype::DataType, indextype::DataType)
if numbertype == NUMBERTYPE_NONE
quote
const IndexTuple{N_tuple} = NTuple{N_tuple, Index}
Expand Down Expand Up @@ -276,11 +293,12 @@ function Data_shared(numbertype::DataType, indextype::DataType)
const CellArrayCollection{N_tuple, T_elem, N, B} = Union{CellArrayTuple{N_tuple, T_elem, N, B}, NamedCellArrayTuple{N_tuple, T_elem, N, B}}
const DeviceCellArrayCollection{N_tuple, T_elem, N, B} = Union{DeviceCellArrayTuple{N_tuple, T_elem, N, B}, NamedDeviceCellArrayTuple{N_tuple, T_elem, N, B}}

NamedIndexTuple{}(t::NamedTuple) = Base.map(Data.Index, t)
NamedNumberTuple{}(T, t::NamedTuple) = Base.map(T, t)
NamedArrayTuple{}(T, t::NamedTuple) = Base.map(Data.Array{T}, t)
NamedCellTuple{}(T, t::NamedTuple) = Base.map(Data.Cell{T}, t)
NamedCellArrayTuple{}(T, t::NamedTuple) = Base.map(Data.CellArray{T}, t)
# TODO: the following constructors lead to pre-compilation issues due to a bug in Julia. They are therefore commented out for now.
# NamedIndexTuple{}(t::NamedTuple) = Base.map(Data.Index, t)
# NamedNumberTuple{}(T, t::NamedTuple) = Base.map(T, t)
# NamedArrayTuple{}(T, t::NamedTuple) = Base.map(Data.Array{T}, t)
# NamedCellTuple{}(T, t::NamedTuple) = Base.map(Data.Cell{T}, t)
# NamedCellArrayTuple{}(T, t::NamedTuple) = Base.map(Data.CellArray{T}, t)
end
else
quote
Expand Down Expand Up @@ -332,15 +350,16 @@ function Data_shared(numbertype::DataType, indextype::DataType)
const TCellArrayCollection{N_tuple, T_elem, N, B} = Union{TCellArrayTuple{N_tuple, T_elem, N, B}, NamedTCellArrayTuple{N_tuple, T_elem, N, B}}
const DeviceTCellArrayCollection{N_tuple, T_elem, N, B} = Union{DeviceTCellArrayTuple{N_tuple, T_elem, N, B}, NamedDeviceTCellArrayTuple{N_tuple, T_elem, N, B}}

NamedIndexTuple{}(t::NamedTuple) = Base.map(Data.Index, t)
NamedNumberTuple{}(t::NamedTuple) = Base.map(Data.Number, t)
NamedArrayTuple{}(t::NamedTuple) = Base.map(Data.Array, t)
NamedCellTuple{}(t::NamedTuple) = Base.map(Data.Cell, t)
NamedCellArrayTuple{}(t::NamedTuple) = Base.map(Data.CellArray, t)
NamedTNumberTuple{}(T, t::NamedTuple) = Base.map(T, t)
NamedTArrayTuple{}(T, t::NamedTuple) = Base.map(Data.TArray{T}, t)
NamedTCellTuple{}(T, t::NamedTuple) = Base.map(Data.TCell{T}, t)
NamedTCellArrayTuple{}(T, t::NamedTuple) = Base.map(Data.TCellArray{T}, t)
# TODO: the following constructors lead to pre-compilation issues due to a bug in Julia. They are therefore commented out for now.
# NamedIndexTuple{}(t::NamedTuple) = Base.map(Data.Index, t)
# NamedNumberTuple{}(t::NamedTuple) = Base.map(Data.Number, t)
# NamedArrayTuple{}(t::NamedTuple) = Base.map(Data.Array, t)
# NamedCellTuple{}(t::NamedTuple) = Base.map(Data.Cell, t)
# NamedCellArrayTuple{}(t::NamedTuple) = Base.map(Data.CellArray, t)
# NamedTNumberTuple{}(T, t::NamedTuple) = Base.map(T, t)
# NamedTArrayTuple{}(T, t::NamedTuple) = Base.map(Data.TArray{T}, t)
# NamedTCellTuple{}(T, t::NamedTuple) = Base.map(Data.TCell{T}, t)
# NamedTCellArrayTuple{}(T, t::NamedTuple) = Base.map(Data.TCellArray{T}, t)
end
end
end
Expand Down
15 changes: 6 additions & 9 deletions src/ParallelKernel/init_parallel_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,20 @@ macro init_parallel_kernel(args...)
end

function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataType, inbounds::Bool; datadoc_call=:())
modulename = :Data
if package == PKG_CUDA
if (!CUDA_IS_INSTALLED) @NotInstalledError("CUDA was selected as package for parallelization, but CUDA.jl is not installed. CUDA functionality is provided with an extension of ParallelStencil and CUDA.jl needs therefore to be installed independently.") end
if (!is_installed("CUDA")) @NotInstalledError("CUDA was selected as package for parallelization, but CUDA.jl is not installed. CUDA functionality is provided with an extension of ParallelStencil and CUDA.jl needs therefore to be installed independently.") end
indextype = INT_CUDA
data_module = Data_cuda(numbertype, indextype)
data_module_shared = Data_shared(numbertype, indextype)
data_module = Data_cuda(modulename, numbertype, indextype)
pkg_import_cmd = :(import CUDA)
elseif package == PKG_AMDGPU
if (!AMDGPU_IS_INSTALLED) @NotInstalledError("AMDGPU was selected as package for parallelization, but AMDGPU.jl is not installed. AMDGPU functionality is provided with an extension of ParallelStencil and AMDGPU.jl needs therefore to be installed independently.") end
if (!is_installed("AMDGPU")) @NotInstalledError("AMDGPU was selected as package for parallelization, but AMDGPU.jl is not installed. AMDGPU functionality is provided with an extension of ParallelStencil and AMDGPU.jl needs therefore to be installed independently.") end
indextype = INT_AMDGPU
data_module = Data_amdgpu(numbertype, indextype)
data_module_shared = Data_shared(numbertype, indextype)
data_module = Data_amdgpu(modulename, numbertype, indextype)
pkg_import_cmd = :(import AMDGPU)
elseif package == PKG_THREADS
indextype = INT_THREADS
data_module = Data_threads(numbertype, indextype)
data_module_shared = Data_shared(numbertype, indextype)
data_module = Data_threads(modulename, numbertype, indextype)
pkg_import_cmd = :()
end
ad_init_cmd = :(ParallelStencil.ParallelKernel.AD.init_AD(ParallelStencil.ParallelKernel.PKG_THREADS))
Expand All @@ -53,7 +51,6 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
end
@eval(caller, $pkg_import_cmd)
@eval(caller, $data_module)
@eval(caller.Data, $data_module_shared)
@eval(caller, $datadoc_call)
elseif isdefined(caller, :Data) && isdefined(caller.Data, :DeviceArray)
if !isinteractive() @warn "Module Data from previous module initialization found in caller module ($caller); module Data not created. Note: this warning is only shown in non-interactive mode." end
Expand Down
6 changes: 2 additions & 4 deletions src/ParallelKernel/shared.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using CellArrays, StaticArrays, MacroTools
import MacroTools: postwalk, splitdef, combinedef, isexpr, unblock # NOTE: inexpr_walk used instead of MacroTools.inexpr
import MacroTools: postwalk, splitdef, combinedef, isexpr, unblock, flatten, rmlines, prewalk # NOTE: inexpr_walk used instead of MacroTools.inexpr


## CONSTANTS AND TYPES (and the macros wrapping them)
Expand All @@ -10,8 +10,6 @@ gensym_world(tag::String, generator::Module) = gensym(string(tag, GENSYM_SEPARAT
gensym_world(tag::Symbol, generator::Module) = gensym(string(tag, GENSYM_SEPARATOR, generator))
gensym_world(tag::Expr, generator::Module) = gensym(string(tag, GENSYM_SEPARATOR, generator))

const CUDA_IS_INSTALLED = (Base.find_package("CUDA")!==nothing)
const AMDGPU_IS_INSTALLED = (Base.find_package("AMDGPU")!==nothing)
const PKG_CUDA = :CUDA
const PKG_AMDGPU = :AMDGPU
const PKG_THREADS = :Threads
Expand Down Expand Up @@ -68,10 +66,10 @@ macro ranges() esc(RANGES_VARNAME) end
macro rangelengths() esc(:(($(RANGELENGTHS_VARNAMES...),))) end



## FUNCTIONS TO CHECK EXTENSIONS SUPPORT

is_loaded(arg) = false
is_installed(package::String) = (Base.find_package(package)!==nothing)


## FUNCTIONS TO DEAL WITH KERNEL DEFINITIONS: SIGNATURES, BODY AND RETURN STATEMENT
Expand Down

0 comments on commit d20832d

Please sign in to comment.