Skip to content

Commit

Permalink
tests pass up to random now.
Browse files Browse the repository at this point in the history
  • Loading branch information
leios committed Mar 28, 2024
1 parent 23eaa50 commit 8880ad4
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 13 deletions.
6 changes: 6 additions & 0 deletions lib/JLArrays/src/JLArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, Dyn

export JLArray, JLVector, JLMatrix, jl, JLBackend

#
# Device functionality
#

const MAXTHREADS = 256

struct JLBackend <: KernelAbstractions.GPU
static::Bool
JLBackend(;static::Bool=false) = new(static)
Expand Down
3 changes: 2 additions & 1 deletion src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ end
@inbounds dest[I] = bc′[I]
end

broadcast_kernel(get_backend(dest))(dest, bc′, ndrange = size(dest))
# ndrange set for a possible 0D evaluation
broadcast_kernel(get_backend(dest))(dest, bc′, ndrange = length(size(dest)) > 0 ? size(dest) : (1,))

return dest
end
Expand Down
5 changes: 4 additions & 1 deletion src/host/construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ function Base.fill!(A::AnyGPUArray{T}, x) where T
idx = @index(Global, Linear)
@inbounds a[idx] = val
end
fill_kernel!(get_backend(A))(A, x, ndrange = size(A))

# ndrange set for a possible 0D evaluation
fill_kernel!(get_backend(A))(A, x,
ndrange = length(size(A)) > 0 ? size(A) : (1,))
A
end

Expand Down
14 changes: 7 additions & 7 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
@inbounds _A[j,i] = conj(_A[i,j])
end
end
U_conj!(get_backend(_A), ndrange = size(_A))
U_conj!(get_backend(A))(A, ndrange = size(A))
elseif uplo == 'U' && !conjugate
@kernel function U_noconj!(_A)
I = @index(Global, Cartesian)
Expand All @@ -101,7 +101,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
@inbounds _A[j,i] = _A[i,j]
end
end
U_noconj!(get_backend(_A))(_A, ndrange=size(_A))
U_noconj!(get_backend(A))(A, ndrange = size(A))
elseif uplo == 'L' && conjugate
@kernel function L_conj!(_A)
I = @index(Global, Cartesian)
Expand All @@ -110,7 +110,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
@inbounds _A[i,j] = conj(_A[j,i])
end
end
L_conj!(get_backend(_A))(_A, ndrange = size(_A))
L_conj!(get_backend(A))(A, ndrange = size(A))
elseif uplo == 'L' && !conjugate
@kernel function L_noconj!(_A)
I = @index(Global, Cartesian)
Expand All @@ -119,7 +119,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
@inbounds _A[i,j] = _A[j,i]
end
end
L_noconj!(get_backend(_A))(_A, ndrange = size(_A))
L_noconj!(get_backend(A))(A, ndrange = size(A))
else
throw(ArgumentError("uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
end
Expand Down Expand Up @@ -178,7 +178,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
@inbounds _A[i, j] = zero(T)
end
end
tril_kernel!(get_backend(_A))(_A, _d, ndrange = size(_A))
tril_kernel!(get_backend(A))(A, d, ndrange = size(A))
return A
end

Expand All @@ -190,7 +190,7 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
@inbounds _A[i, j] = zero(T)
end
end
triu_kernel!(get_backend(_A))(_A, _d, ndrange = length(_A))
triu_kernel!(get_backend(A))(A, d, ndrange = size(A))
return A
end

Expand Down Expand Up @@ -423,7 +423,7 @@ LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b)

function generic_lmul!(s::Number, X::AbstractArray)
@kernel function lmul_kernel!(X, s)
i = @index(Global, linear)
i = @index(Global, Linear)
@inbounds X[i] = s*X[i]
end
lmul_kernel!(get_backend(X))(X, s, ndrange = size(X))
Expand Down
2 changes: 1 addition & 1 deletion src/host/math.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Base mathematical operations

function Base.clamp!(A::AnyGPUArray, low, high)
@kernel function clamp_kernel!(A::AnyGPUArray, low, high)
@kernel function clamp_kernel!(A, low, high)
I = @index(Global, Cartesian)
A[I] = clamp(A[I], low, high)
end
Expand Down
5 changes: 2 additions & 3 deletions src/host/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,9 @@ end
function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
@kernel function rand!(a, randstate)
idx = @index(Global, Linear)
@inbounds a[idx] = gpu_rand(T, idx, randstates)
@inbounds a[idx] = gpu_rand(T, idx, randstate)
end
kernel = rand!(get_backend(A))
kernel(A, rng.state)
rand!(get_backend(A))(A, rng.state, ndrange = size(A))
A
end

Expand Down

0 comments on commit 8880ad4

Please sign in to comment.