-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #144 from omlins/gpu-ext
Use extensions for GPU dependencies
- Loading branch information
Showing
33 changed files
with
240 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Extension module for AMDGPU support: it includes the shared definitions and the allocator methods | ||
# stored under src/ParallelKernel/AMDGPUExt (paths are resolved relative to this file's directory). | ||
module ParallelStencil_AMDGPUExt | ||
include(joinpath(@__DIR__, "..", "src", "ParallelKernel", "AMDGPUExt", "shared.jl")) | ||
include(joinpath(@__DIR__, "..", "src", "ParallelKernel", "AMDGPUExt", "allocators.jl")) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Extension module for CUDA support: it includes the shared definitions and the allocator methods | ||
# stored under src/ParallelKernel/CUDAExt (paths are resolved relative to this file's directory). | ||
module ParallelStencil_CUDAExt | ||
include(joinpath(@__DIR__, "..", "src", "ParallelKernel", "CUDAExt", "shared.jl")) | ||
include(joinpath(@__DIR__, "..", "src", "ParallelKernel", "CUDAExt", "allocators.jl")) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
## RUNTIME ALLOCATOR FUNCTIONS | ||
|
||
# Allocator methods for plain element types (Number, Enum, Bool). | ||
# All of them take `blocklength` as second positional argument; zeros/ones/falses/trues do not use it, | ||
# while rand/fill pass it on to rand_cpu/fill_cpu and wrap the result in a ROCArray. | ||
# zeros/ones call check_datatype_amdgpu(T) before allocating. | ||
ParallelStencil.ParallelKernel.zeros_amdgpu(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_amdgpu(T); AMDGPU.zeros(T, args...)) # (blocklength is ignored if neither celldims nor celltype is set) | ||
ParallelStencil.ParallelKernel.ones_amdgpu(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_amdgpu(T); AMDGPU.ones(T, args...)) # ... | ||
ParallelStencil.ParallelKernel.rand_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = ROCArray(rand_cpu(T, blocklength, args...)) # ... | ||
ParallelStencil.ParallelKernel.falses_amdgpu(::Type{T}, blocklength, args...) where {T<:Bool} = AMDGPU.falses(args...) # ... | ||
ParallelStencil.ParallelKernel.trues_amdgpu(::Type{T}, blocklength, args...) where {T<:Bool} = AMDGPU.trues(args...) # ... | ||
ParallelStencil.ParallelKernel.fill_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = ROCArray(fill_cpu(T, blocklength, args...)) # ... | ||
|
||
# Allocator methods for cell types (T <: SArray or FieldArray). | ||
# zeros/ones/falses/trues delegate to fill_amdgpu with the constants 0, 1, false and true. | ||
# rand (Val{B} method) builds a CellArray on top of a 3-D ROCArray of size | ||
# (blocklen, prod(size(T)), ceil(Int, prod(dims)/blocklen)), where blocklen is prod(dims) if B == 0 and B otherwise; | ||
# the random data is produced with Base.rand and then converted to a ROCArray. | ||
# The varargs rand method only repacks `dims...` into a tuple and calls rand_amdgpu again. | ||
ParallelStencil.ParallelKernel.zeros_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_amdgpu(T); fill_amdgpu(T, blocklength, 0, args...)) | ||
ParallelStencil.ParallelKernel.ones_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_amdgpu(T); fill_amdgpu(T, blocklength, 1, args...)) | ||
ParallelStencil.ParallelKernel.rand_amdgpu(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_amdgpu(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, AMDGPU.ROCArray{eltype(T),3}}(AMDGPU.ROCArray(Base.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen)))), dims)) # TODO: use AMDGPU.rand! instead of AMDGPU.rand once it supports Enums: rand_amdgpu(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_amdgpu(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, AMDGPU.ROCArray{eltype(T),3}}(AMDGPU.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims)) | ||
ParallelStencil.ParallelKernel.rand_amdgpu(::Type{T}, blocklength, dims...) where {T<:Union{SArray,FieldArray}} = rand_amdgpu(T, blocklength, dims) | ||
ParallelStencil.ParallelKernel.falses_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_amdgpu(T, blocklength, false, args...) | ||
ParallelStencil.ParallelKernel.trues_amdgpu(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_amdgpu(T, blocklength, true, args...) | ||
|
||
# Create a ROCCellArray{T,B} (constructed with `undef, args...`) and fill every cell with a value built from `x`: | ||
# - an `x` of length 1 is converted to eltype(T) and replicated to size(T); | ||
# - an `x` with length(T) elements is converted to T directly; | ||
# - any other length raises an @ArgumentError. | ||
function ParallelStencil.ParallelKernel.fill_amdgpu(::Type{T}, ::Val{B}, x, args...) where {T <: Union{SArray,FieldArray}, B} | ||
# Reject an `x` whose element type is Bool or not a Number unless it equals eltype(T); no automatic conversion is attempted in that case. | ||
if (!(eltype(x) <: Number) || (eltype(x) == Bool)) && (eltype(x) != eltype(T)) @ArgumentError("fill: the (element) type of argument 'x' is not a normal number type ($(eltype(x))), but does not match the obtained (default) 'eltype' ($(eltype(T))); automatic conversion to $(eltype(T)) is therefore not attempted. Set the keyword argument 'eltype' accordingly to the element type of 'x' or pass an 'x' of a different (element) type.") end | ||
# NOTE(review): the extra Bool/Enum arguments presumably extend the set of accepted element types - confirm in check_datatype. | ||
check_datatype_amdgpu(T, Bool, Enum) | ||
if (length(x) == 1) cell = convert(T, fill(convert(eltype(T), x), size(T))) | ||
elseif (length(x) == length(T)) cell = convert(T, x) | ||
else @ArgumentError("fill: argument 'x' contains the wrong number of elements ($(length(x))). It must be a scalar or contain the number of elements defined by 'celldims'.") | ||
end | ||
return CellArrays.fill!(ROCCellArray{T,B}(undef, args...), cell) | ||
end | ||
|
||
# In-place fill: the fill value is obtained from construct_cell(A, x) and handed to AMDGPU.fill! together with `A`. | ||
ParallelStencil.ParallelKernel.fill_amdgpu!(A, x) = AMDGPU.fill!(A, construct_cell(A, x)) | ||
|
||
# Forward all arguments to check_datatype, appending INT_AMDGPU as the last argument. | ||
check_datatype_amdgpu(args...) = check_datatype(args..., INT_AMDGPU) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Message passed to @NotLoadedError by the fallback *_amdgpu methods defined in this file. | ||
const ERRMSG_AMDGPUEXT_NOT_LOADED = "the AMDGPU extension was not loaded. Make sure to import AMDGPU before ParallelStencil." | ||
|
||
|
||
# shared.jl | ||
|
||
# Declare the generic functions without any method; methods are attached to them from AMDGPUExt/shared.jl. | ||
function get_priority_rocstream end | ||
function get_rocstream end | ||
|
||
|
||
# allocators.jl | ||
|
||
# Catch-all fallbacks: every AMDGPU allocator accepts any arguments and raises @NotLoadedError. | ||
# The methods added in AMDGPUExt/allocators.jl have more specific signatures than these varargs methods. | ||
zeros_amdgpu(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) | ||
ones_amdgpu(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) | ||
rand_amdgpu(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) | ||
falses_amdgpu(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) | ||
trues_amdgpu(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) | ||
fill_amdgpu(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) | ||
fill_amdgpu!(arg...) = @NotLoadedError(ERRMSG_AMDGPUEXT_NOT_LOADED) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import ParallelStencil | ||
import ParallelStencil.ParallelKernel: INT_AMDGPU, rand_cpu, fill_cpu, construct_cell, check_datatype, rand_amdgpu, fill_amdgpu | ||
using ParallelStencil.ParallelKernel.Exceptions | ||
using AMDGPU, CellArrays, StaticArrays | ||
@define_ROCCellArray | ||
|
||
|
||
## FUNCTIONS TO CHECK EXTENSIONS SUPPORT | ||
|
||
# Add an is_loaded method that returns true for the tag Val{:ParallelStencil_AMDGPUExt}. | ||
ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_AMDGPUExt}) = true | ||
|
||
|
||
## FUNCTIONS TO GET, CREATE AND MANAGE AMDGPU QUEUES AND "ROCSTREAMS" | ||
|
||
# Add methods to the ParallelKernel generic functions that forward to the module-level functions of the same name, | ||
# which are declared `global` and defined inside the `let` block that follows. | ||
ParallelStencil.ParallelKernel.get_priority_rocstream(arg...) = get_priority_rocstream(arg...) | ||
ParallelStencil.ParallelKernel.get_rocstream(arg...) = get_rocstream(arg...) | ||
# The `let` block keeps the two stream arrays local; they are only reachable through the two functions defined in it. | ||
let | ||
global get_priority_rocstream, get_rocstream | ||
priority_rocstreams = Array{AMDGPU.HIPStream}(undef, 0) | ||
rocstreams = Array{AMDGPU.HIPStream}(undef, 0) | ||
|
||
# Return the stream stored at index `id`, first appending AMDGPU.HIPStream(:high) entries until the array has `id` elements. | ||
# NOTE(review): there is no check for id < 1; indexing would fail in that case. | ||
function get_priority_rocstream(id::Integer) | ||
while (id > length(priority_rocstreams)) push!(priority_rocstreams, AMDGPU.HIPStream(:high)) end | ||
return priority_rocstreams[id] | ||
end | ||
|
||
# Same as get_priority_rocstream, but new entries are created with AMDGPU.HIPStream(:low). | ||
function get_rocstream(id::Integer) | ||
while (id > length(rocstreams)) push!(rocstreams, AMDGPU.HIPStream(:low)) end | ||
return rocstreams[id] | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
## RUNTIME ALLOCATOR FUNCTIONS | ||
|
||
# Allocator methods for plain element types (Number, Enum, Bool). | ||
# All of them take `blocklength` as second positional argument; zeros/ones/falses/trues do not use it, | ||
# while rand/fill pass it on to rand_cpu/fill_cpu and wrap the result in a CuArray. | ||
# zeros/ones call check_datatype_cuda(T) before allocating; falses/trues are expressed as CUDA.zeros/CUDA.ones of Bool. | ||
ParallelStencil.ParallelKernel.zeros_cuda(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_cuda(T); CUDA.zeros(T, args...)) # (blocklength is ignored if neither celldims nor celltype is set) | ||
ParallelStencil.ParallelKernel.ones_cuda(::Type{T}, blocklength, args...) where {T<:Number} = (check_datatype_cuda(T); CUDA.ones(T, args...)) # ... | ||
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = CuArray(rand_cpu(T, blocklength, args...)) # ... | ||
ParallelStencil.ParallelKernel.falses_cuda(::Type{T}, blocklength, args...) where {T<:Bool} = CUDA.zeros(Bool, args...) # ... | ||
ParallelStencil.ParallelKernel.trues_cuda(::Type{T}, blocklength, args...) where {T<:Bool} = CUDA.ones(Bool, args...) # ... | ||
ParallelStencil.ParallelKernel.fill_cuda(::Type{T}, blocklength, args...) where {T<:Union{Number,Enum}} = CuArray(fill_cpu(T, blocklength, args...)) # ... | ||
|
||
# Allocator methods for cell types (T <: SArray or FieldArray). | ||
# zeros/ones/falses/trues delegate to fill_cuda with the constants 0, 1, false and true. | ||
# rand (Val{B} method) builds a CellArray on top of a 3-D CuArray obtained from CUDA.rand with size | ||
# (blocklen, prod(size(T)), ceil(Int, prod(dims)/blocklen)), where blocklen is prod(dims) if B == 0 and B otherwise. | ||
# The varargs rand method only repacks `dims...` into a tuple and calls rand_cuda again. | ||
ParallelStencil.ParallelKernel.zeros_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_cuda(T); fill_cuda(T, blocklength, 0, args...)) | ||
ParallelStencil.ParallelKernel.ones_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = (check_datatype_cuda(T); fill_cuda(T, blocklength, 1, args...)) | ||
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, ::Val{B}, dims) where {T<:Union{SArray,FieldArray}, B} = (check_datatype_cuda(T, Bool, Enum); blocklen = (B == 0) ? prod(dims) : B; CellArray{T,length(dims),B, CUDA.CuArray{eltype(T),3}}(CUDA.rand(eltype(T), blocklen, prod(size(T)), ceil(Int,prod(dims)/(blocklen))), dims)) | ||
ParallelStencil.ParallelKernel.rand_cuda(::Type{T}, blocklength, dims...) where {T<:Union{SArray,FieldArray}} = rand_cuda(T, blocklength, dims) | ||
ParallelStencil.ParallelKernel.falses_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_cuda(T, blocklength, false, args...) | ||
ParallelStencil.ParallelKernel.trues_cuda(::Type{T}, blocklength, args...) where {T<:Union{SArray,FieldArray}} = fill_cuda(T, blocklength, true, args...) | ||
|
||
# Create a CuCellArray{T,B} (constructed with `undef, args...`) and fill every cell with a value built from `x`: | ||
# - an `x` of length 1 is converted to eltype(T) and replicated to size(T); | ||
# - an `x` with length(T) elements is converted to T directly; | ||
# - any other length raises an @ArgumentError. | ||
function ParallelStencil.ParallelKernel.fill_cuda(::Type{T}, ::Val{B}, x, args...) where {T <: Union{SArray,FieldArray}, B} | ||
# Reject an `x` whose element type is Bool or not a Number unless it equals eltype(T); no automatic conversion is attempted in that case. | ||
if (!(eltype(x) <: Number) || (eltype(x) == Bool)) && (eltype(x) != eltype(T)) @ArgumentError("fill: the (element) type of argument 'x' is not a normal number type ($(eltype(x))), but does not match the obtained (default) 'eltype' ($(eltype(T))); automatic conversion to $(eltype(T)) is therefore not attempted. Set the keyword argument 'eltype' accordingly to the element type of 'x' or pass an 'x' of a different (element) type.") end | ||
# NOTE(review): the extra Bool/Enum arguments presumably extend the set of accepted element types - confirm in check_datatype. | ||
check_datatype_cuda(T, Bool, Enum) | ||
if (length(x) == 1) cell = convert(T, fill(convert(eltype(T), x), size(T))) | ||
elseif (length(x) == length(T)) cell = convert(T, x) | ||
else @ArgumentError("fill: argument 'x' contains the wrong number of elements ($(length(x))). It must be a scalar or contain the number of elements defined by 'celldims'.") | ||
end | ||
return CellArrays.fill!(CuCellArray{T,B}(undef, args...), cell) | ||
end | ||
|
||
# In-place fill: the fill value is obtained from construct_cell(A, x) and handed to CUDA.fill! together with `A`. | ||
ParallelStencil.ParallelKernel.fill_cuda!(A, x) = CUDA.fill!(A, construct_cell(A, x)) | ||
|
||
# Forward all arguments to check_datatype, appending INT_CUDA as the last argument. | ||
check_datatype_cuda(args...) = check_datatype(args..., INT_CUDA) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Message passed to @NotLoadedError by the fallback *_cuda methods defined in this file. | ||
const ERRMSG_CUDAEXT_NOT_LOADED = "the CUDA extension was not loaded. Make sure to import CUDA before ParallelStencil." | ||
|
||
|
||
# shared.jl | ||
|
||
# Declare the generic functions without any method; methods are attached to them from CUDAExt/shared.jl. | ||
function get_priority_custream end | ||
function get_custream end | ||
|
||
|
||
# allocators.jl | ||
|
||
# Catch-all fallbacks: every CUDA allocator accepts any arguments and raises @NotLoadedError. | ||
# The methods added in CUDAExt/allocators.jl have more specific signatures than these varargs methods. | ||
zeros_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) | ||
ones_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) | ||
rand_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) | ||
falses_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) | ||
trues_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) | ||
fill_cuda(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) | ||
fill_cuda!(arg...) = @NotLoadedError(ERRMSG_CUDAEXT_NOT_LOADED) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import ParallelStencil | ||
import ParallelStencil.ParallelKernel: INT_CUDA, rand_cpu, fill_cpu, construct_cell, check_datatype, rand_cuda, fill_cuda | ||
using ParallelStencil.ParallelKernel.Exceptions | ||
using CUDA, CellArrays, StaticArrays | ||
@define_CuCellArray | ||
|
||
|
||
## FUNCTIONS TO CHECK EXTENSIONS SUPPORT | ||
|
||
# Add an is_loaded method that returns true for the tag Val{:ParallelStencil_CUDAExt}. | ||
ParallelStencil.ParallelKernel.is_loaded(::Val{:ParallelStencil_CUDAExt}) = true | ||
|
||
|
||
## FUNCTIONS TO GET, CREATE AND MANAGE CUDA STREAMS | ||
|
||
# Add methods to the ParallelKernel generic functions that forward to the module-level functions of the same name, | ||
# which are declared `global` and defined inside the `let` block that follows. | ||
ParallelStencil.ParallelKernel.get_priority_custream(arg...) = get_priority_custream(arg...) | ||
ParallelStencil.ParallelKernel.get_custream(arg...) = get_custream(arg...) | ||
# The `let` block keeps the two stream arrays local; they are only reachable through the two functions defined in it. | ||
let | ||
global get_priority_custream, get_custream | ||
priority_custreams = Array{CuStream}(undef, 0) | ||
custreams = Array{CuStream}(undef, 0) | ||
|
||
# Return the stream stored at index `id`, first appending non-blocking CuStreams created with the last value of CUDA.priority_range() until the array has `id` elements. | ||
# NOTE(review): there is no check for id < 1; indexing would fail in that case. | ||
function get_priority_custream(id::Integer) | ||
while (id > length(priority_custreams)) push!(priority_custreams, CuStream(; flags=CUDA.STREAM_NON_BLOCKING, priority=CUDA.priority_range()[end])) end # CUDA.priority_range()[end] is max priority. # NOTE: priority_range cannot be called outside the function, as it is only certain at runtime that CUDA is functional. | ||
return priority_custreams[id] | ||
end | ||
|
||
# Same as get_priority_custream, but new entries are created with the first value of CUDA.priority_range(). | ||
function get_custream(id::Integer) | ||
while (id > length(custreams)) push!(custreams, CuStream(; flags=CUDA.STREAM_NON_BLOCKING, priority=CUDA.priority_range()[1])) end # CUDA.priority_range()[1] is min priority. # NOTE: priority_range cannot be called outside the function, as it is only certain at runtime that CUDA is functional. | ||
return custreams[id] | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.