Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test cudss_set with the data parameter user_perm #27

Merged
merged 3 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,9 @@ end
function cudss_set(data::CudssData, param::String, value)
(param ∈ CUDSS_DATA_PARAMETERS) || throw(ArgumentError("Unknown data parameter $param."))
(param == "user_perm") || throw(ArgumentError("Only the data parameter \"user_perm\" can be set."))
type = CUDSS_TYPES[param]
val = Ref{type}(value)
nbytes = sizeof(val)
cudssDataSet(handle(), data, param, val, nbytes)
(value isa Vector{Cint} || value isa CuVector{Cint}) || throw(ArgumentError("The permutation is neither a Vector{Cint} nor a CuVector{Cint}."))
nbytes = sizeof(value)
cudssDataSet(handle(), data, param, value, nbytes)
end

function cudss_set(config::CudssConfig, param::String, value)
Expand Down
4 changes: 2 additions & 2 deletions src/libcudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,14 @@ end
@checked function cudssDataSet(handle, data, param, value, sizeInBytes)
initialize_context()
@ccall libcudss.cudssDataSet(handle::cudssHandle_t, data::cudssData_t,
param::cudssDataParam_t, value::Ptr{Cvoid},
param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid},
sizeInBytes::Cint)::cudssStatus_t
end

@checked function cudssDataGet(handle, data, param, value, sizeInBytes, sizeWritten)
initialize_context()
@ccall libcudss.cudssDataGet(handle::cudssHandle_t, data::cudssData_t,
param::cudssDataParam_t, value::Ptr{Cvoid},
param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid},
sizeInBytes::Cint, sizeWritten::Ptr{Cint})::cudssStatus_t
end

Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ include("test_cudss.jl")
@testset "Generic API" begin
cudss_generic()
end

@testset "User permutation" begin
user_permutation()
end
end
103 changes: 103 additions & 0 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,106 @@ function cudss_generic()
end
end
end

function user_permutation()
function permutation_lu(T, A_cpu, x_cpu, b_cpu, permutation)
A_gpu = CuSparseMatrixCSR(A_cpu)
x_gpu = CuVector(x_cpu)
b_gpu = CuVector(b_cpu)

solver = CudssSolver(A_gpu, "G", 'F')

cudss_set(solver, "user_perm", permutation)

cudss("analysis", solver, x_gpu, b_gpu)
cudss("factorization", solver, x_gpu, b_gpu)
cudss("solve", solver, x_gpu, b_gpu)

nz = cudss_get(solver, "lu_nnz")
return nz
end

function permutation_ldlt(T, A_cpu, x_cpu, b_cpu, permutation)
A_gpu = CuSparseMatrixCSR(A_cpu |> tril)
x_gpu = CuVector(x_cpu)
b_gpu = CuVector(b_cpu)

structure = T <: Real ? "S" : "H"
solver = CudssSolver(A_gpu, structure, 'L')
cudss_set(solver, "user_perm", permutation)

cudss("analysis", solver, x_gpu, b_gpu)
cudss("factorization", solver, x_gpu, b_gpu)
cudss("solve", solver, x_gpu, b_gpu)

nz = cudss_get(solver, "lu_nnz")
return nz
end

function permutation_llt(T, A_cpu, x_cpu, b_cpu, permutation)
A_gpu = CuSparseMatrixCSR(A_cpu |> triu)
x_gpu = CuVector(x_cpu)
b_gpu = CuVector(b_cpu)

structure = T <: Real ? "SPD" : "HPD"
solver = CudssSolver(A_gpu, structure, 'U')
cudss_set(solver, "user_perm", permutation)

cudss("analysis", solver, x_gpu, b_gpu)
cudss("factorization", solver, x_gpu, b_gpu)
cudss("solve", solver, x_gpu, b_gpu)

nz = cudss_get(solver, "lu_nnz")
return nz
end

n = 1000
perm1_cpu = Vector{Cint}(undef, n)
perm2_cpu = Vector{Cint}(undef, n)
for i = 1:n
perm1_cpu[i] = i
perm2_cpu[i] = n-i+1
end
perm1_gpu = CuVector{Cint}(perm1_cpu)
perm2_gpu = CuVector{Cint}(perm2_cpu)
@testset "precision = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "LU" begin
A_cpu = sprand(T, n, n, 0.05) + I
x_cpu = zeros(T, n)
b_cpu = rand(T, n)
nz1_cpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm1_cpu)
nz2_cpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm2_cpu)
nz1_gpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm1_gpu)
nz2_gpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm2_gpu)
@test nz1_cpu == nz1_gpu
@test nz2_cpu == nz2_gpu
@test nz1_cpu != nz2_cpu
end
@testset "LDLᵀ / LDLᴴ" begin
A_cpu = sprand(T, n, n, 0.05) + I
A_cpu = A_cpu + A_cpu'
x_cpu = zeros(T, n)
b_cpu = rand(T, n)
nz1_cpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm1_cpu)
nz2_cpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm2_cpu)
nz1_gpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm1_gpu)
nz2_gpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm2_gpu)
@test nz1_cpu == nz1_gpu
@test nz2_cpu == nz2_gpu
@test nz1_cpu != nz2_cpu
end
@testset "LLᵀ / LLᴴ" begin
A_cpu = sprand(T, n, n, 0.01)
A_cpu = A_cpu * A_cpu' + I
x_cpu = zeros(T, n)
b_cpu = rand(T, n)
nz1_cpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm1_cpu)
nz2_cpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm2_cpu)
nz1_gpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm1_gpu)
nz2_gpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm2_gpu)
@test nz1_cpu == nz1_gpu
@test nz2_cpu == nz2_gpu
@test nz1_cpu != nz2_cpu
end
end
end
Loading