Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
leios committed Mar 29, 2024
1 parent 8880ad4 commit d325a8c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 32 deletions.
6 changes: 2 additions & 4 deletions src/host/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@ function next_rand(state::NTuple{4, T}) where {T <: Unsigned}
end

function gpu_rand(::Type{T}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T
threadid = GPUArrays.threadidx(ctx)
stateful_rand = next_rand(randstate[threadid])
randstate[threadid] = stateful_rand[1]
return make_rand_num(T, stateful_rand[2])
end

function gpu_rand(::Type{T}, threadid, randstate::AbstractVector{NTuple{4, UInt32}}) where T <: Integer
threadid = GPUArrays.threadidx(ctx)
result = zero(T)
if sizeof(T) >= 4
for _ in 1:sizeof(T) >> 2
Expand Down Expand Up @@ -86,7 +84,7 @@ end
function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
@kernel function rand!(a, randstate)
idx = @index(Global, Linear)
@inbounds a[idx] = gpu_rand(T, idx, randstate)
@inbounds a[idx] = gpu_rand(T, ((idx-1)%length(randstate)+1), randstate)
end
rand!(get_backend(A))(A, rng.state, ndrange = size(A))
A
Expand All @@ -108,7 +106,7 @@ function Random.randn!(rng::RNG, A::AnyGPUArray{T}) where T <: Number
end
end
kernel = randn!(get_backend(A))
kernel(A, rng.states; ndrange=threads)
kernel(A, rng.state; ndrange=threads)
A
end

Expand Down
40 changes: 14 additions & 26 deletions src/host/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,16 @@ const unittriangularwrappers = (
(:UnitLowerTriangular, :LowerTriangular)
)

@kernel function kernel_generic(ctx, B, J, min_size)
@kernel function kernel_generic(B, J)
lin_idx = @index(Global, Linear)
if lin_idx <= min_size
@inbounds diag_idx = diagind(B)[lin_idx]
@inbounds B[diag_idx] += J
end
@inbounds diag_idx = diagind(B)[lin_idx]
@inbounds B[diag_idx] += J
end

@kernel function kernel_unittriangular(ctx, B, J, diagonal_val, min_size)
@kernel function kernel_unittriangular(B, J, diagonal_val)
lin_idx = @index(Global, Linear)
if lin_idx <= min_size
@inbounds diag_idx = diagind(B)[lin_idx]
@inbounds B[diag_idx] = diagonal_val + J
end
@inbounds diag_idx = diagind(B)[lin_idx]
@inbounds B[diag_idx] = diagonal_val + J
end

for (t1, t2) in unittriangularwrappers
Expand All @@ -34,17 +30,15 @@ for (t1, t2) in unittriangularwrappers
B = similar(parent(A), typeof(oneunit(T) + J))
copyto!(B, parent(A))
min_size = minimum(size(B))
kernel = kernel_unittriangular(get_backend(B))
kernel(B, J, one(eltype(B)), min_size; ndrange=min_size)
kernel_unittriangular(get_backend(B))(B, J, one(eltype(B)); ndrange=min_size)
return $t2(B)
end

function (-)(J::UniformScaling, A::$t1{T, <:AbstractGPUMatrix}) where T
B = similar(parent(A), typeof(J - oneunit(T)))
B .= .- parent(A)
min_size = minimum(size(B))
kernel = kernel_unittriangular(get_backend(B))
kernel(B, J, -one(eltype(B)), min_size; ndrange=min_size)
kernel_unittriangular(get_backend(B))(B, J, -one(eltype(B)); ndrange=min_size)
return $t2(B)
end
end
Expand All @@ -56,17 +50,15 @@ for t in genericwrappers
B = similar(parent(A), typeof(oneunit(T) + J))
copyto!(B, parent(A))
min_size = minimum(size(B))
kernel = kernel_generic(get_backend(B))
kernel(B, J, min_size; ndrange=min_size)
kernel_generic(get_backend(B))(B, J; ndrange=min_size)
return $t(B)
end

function (-)(J::UniformScaling, A::$t{T, <:AbstractGPUMatrix}) where T
B = similar(parent(A), typeof(J - oneunit(T)))
B .= .- parent(A)
min_size = minimum(size(B))
kernel = kernel_generic(get_backend(B))
kernel(B, J, min_size; ndrange=min_size)
kernel_generic(get_backend(B))(B, J; ndrange=min_size)
return $t(B)
end
end
Expand All @@ -77,17 +69,15 @@ function (+)(A::Hermitian{T,<:AbstractGPUMatrix}, J::UniformScaling{<:Complex})
B = similar(parent(A), typeof(oneunit(T) + J))
copyto!(B, parent(A))
min_size = minimum(size(B))
kernel = kernel_generic(get_backend(B))
kernel(B, J, min_size; ndrange=min_size)
kernel_generic(get_backend(B))(B, J; ndrange=min_size)
return B
end

function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,<:AbstractGPUMatrix}) where T
B = similar(parent(A), typeof(J - oneunit(T)))
B .= .-parent(A)
min_size = minimum(size(B))
kernel = kernel_generic(get_backend(B))
kernel(B, J, min_size; ndrange=min_size)
kernel_generic(get_backend(B))(B, J; ndrange=min_size)
return B
end

Expand All @@ -96,16 +86,14 @@ function (+)(A::AbstractGPUMatrix{T}, J::UniformScaling) where T
B = similar(A, typeof(oneunit(T) + J))
copyto!(B, A)
min_size = minimum(size(B))
kernel = kernel_generic(get_backend(B))
kernel(B, J, min_size; ndrange=min_size)
kernel_generic(get_backend(B))(B, J; ndrange=min_size)
return B
end

function (-)(J::UniformScaling, A::AbstractGPUMatrix{T}) where T
B = similar(A, typeof(J - oneunit(T)))
B .= .-A
min_size = minimum(size(B))
kernel = kernel_generic(get_backend(B))
kernel(B, J, min_size; ndrange=min_size)
kernel_generic(get_backend(B))(B, J; ndrange=min_size)
return B
end
2 changes: 0 additions & 2 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ macro testsuite(name, ex)
end

include("testsuite/construction.jl")
#=
include("testsuite/indexing.jl")
include("testsuite/base.jl")
#include("testsuite/vector.jl")
Expand All @@ -98,7 +97,6 @@ include("testsuite/random.jl")
include("testsuite/uniformscaling.jl")
include("testsuite/statistics.jl")

=#
"""
Runs the entire GPUArrays test suite on array type `AT`
"""
Expand Down
2 changes: 2 additions & 0 deletions test/testsuite/broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ Base.size(A::WrapArray) = size(A.data)
Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data))
# For broadcast support
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
KernelAbstractions.get_backend(a::WA) where WA <: WrapArray = get_backend(a.data)


function unknown_wrapper(AT, eltypes)
for ET in eltypes
Expand Down

0 comments on commit d325a8c

Please sign in to comment.