Skip to content

Commit

Permalink
Use a String for the structure
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 8, 2023
1 parent f5e5798 commit d36e766
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 28 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ A_gpu = CuSparseMatrixCSR(A_cpu)
x_gpu = CuVector(x_cpu)
b_gpu = CuVector(b_cpu)

solver = CudssSolver(A_gpu, 'G', 'F')
solver = CudssSolver(A_gpu, "G", 'F')

cudss("analysis", solver, x_gpu, b_gpu)
cudss("factorization", solver, x_gpu, b_gpu)
Expand Down Expand Up @@ -75,7 +75,7 @@ A_gpu = CuSparseMatrixCSR(A_cpu |> tril)
X_gpu = CuMatrix(X_cpu)
B_gpu = CuMatrix(B_cpu)

structure = T <: Real ? 'S' : 'H'
structure = T <: Real ? "S" : "H"
solver = CudssSolver(A_gpu, structure, 'L')

cudss("analysis", solver, X_gpu, B_gpu)
Expand Down
10 changes: 5 additions & 5 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ export CudssMatrix, CudssData, CudssConfig
"""
matrix = CudssMatrix(v::CuVector)
matrix = CudssMatrix(A::CuMatrix)
matrix = CudssMatrix(A::CuSparseMatrixCSR, struture::Union{Char, String}, view::Char; index::Char='O')
matrix = CudssMatrix(A::CuSparseMatrixCSR, struture::String, view::Char; index::Char='O')
`CudssMatrix` is a wrapper for `CuVector`, `CuMatrix` and `CuSparseMatrixCSR`.
`CudssMatrix` is used to pass matrix of the linear system, as well as solution and right-hand side.
`structure` specifies the stucture for sparse matrices:
- `'G'` or `"G"`: General matrix -- LDU factorization;
- `'S'` or `"S"`: Real symmetric matrix -- LDLᵀ factorization;
- `'H'` or `"H"`: Complex Hermitian matrix -- LDLᴴ factorization;
- `"G"`: General matrix -- LDU factorization;
- `"S"`: Real symmetric matrix -- LDLᵀ factorization;
- `"H"`: Complex Hermitian matrix -- LDLᴴ factorization;
- `"SPD"`: Symmetric positive-definite matrix -- LLᵀ factorization;
- `"HPD"`: Hermitian positive-definite matrix -- LLᴴ factorization.
Expand Down Expand Up @@ -52,7 +52,7 @@ mutable struct CudssMatrix
obj
end

function CudssMatrix(A::CuSparseMatrixCSR, structure::Union{Char, String}, view::Char; index::Char='O')
function CudssMatrix(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
m,n = size(A)
matrix_ref = Ref{cudssMatrix_t}()
cudssMatrixCreateCsr(matrix_ref, m, n, nnz(A), A.rowPtr, CU_NULL,
Expand Down
10 changes: 5 additions & 5 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
export CudssSolver, cudss, cudss_set, cudss_get

"""
solver = CudssSolver(A::CuSparseMatrixCSR, structure::Union{Char, String}, view::Char; index::Char='O')
solver = CudssSolver(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
solver = CudssSolver(matrix::CudssMatrix, config::CudssConfig, data::CudssData)
`CudssSolver` contains all structures required to solve linear systems with cuDSS.
One constructor of `CudssSolver` takes as input the same parameters as [`CudssMatrix`](@ref).
`structure` specifies the stucture for sparse matrices:
- `'G'` or `"G"`: General matrix -- LDU factorization;
- `'S'` or `"S"`: Real symmetric matrix -- LDLᵀ factorization;
- `'H'` or `"H"`: Complex Hermitian matrix -- LDLᴴ factorization;
- `"G"`: General matrix -- LDU factorization;
- `"S"`: Real symmetric matrix -- LDLᵀ factorization;
- `"H"`: Complex Hermitian matrix -- LDLᴴ factorization;
- `"SPD"`: Symmetric positive-definite matrix -- LLᵀ factorization;
- `"HPD"`: Hermitian positive-definite matrix -- LLᴴ factorization.
Expand All @@ -34,7 +34,7 @@ mutable struct CudssSolver
return new(matrix, config, data)
end

function CudssSolver(A::CuSparseMatrixCSR, structure::Union{Char, String}, view::Char; index::Char='O')
function CudssSolver(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
matrix = CudssMatrix(A, structure, view; index)
config = CudssConfig()
data = CudssData()
Expand Down
12 changes: 0 additions & 12 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,6 @@ end

## matrix structure type

function Base.convert(::Type{cudssMatrixType_t}, structure::Char)
if structure == 'G'
return CUDSS_MTYPE_GENERAL
elseif structure == 'S'
return CUDSS_MTYPE_SYMMETRIC
elseif structure == 'H'
return CUDSS_MTYPE_HERMITIAN
else
throw(ArgumentError("Unknown structure $structure"))
end
end

function Base.convert(::Type{cudssMatrixType_t}, structure::String)
if structure == "G"
return CUDSS_MTYPE_GENERAL
Expand Down
8 changes: 4 additions & 4 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function cudss_solver()
A_cpu = sprand(T, n, n, 1.0)
A_cpu = A_cpu + A_cpu'
A_gpu = CuSparseMatrixCSR(A_cpu)
@testset "structure = $structure" for structure in ('G', "G", 'S', "S", 'H', "H", "SPD", "HPD")
@testset "structure = $structure" for structure in ("G", "S", "H", "SPD", "HPD")
@testset "view = $view" for view in ('L', 'U', 'F')
solver = CudssSolver(A_gpu, structure, view)

Expand Down Expand Up @@ -95,7 +95,7 @@ function cudss_solver()
@testset "data parameter = $parameter" for parameter in CUDSS_DATA_PARAMETERS
parameter ("perm_row", "perm_col", "perm_reorder", "diag") && continue
if parameter "user_perm"
(parameter == "inertia") && !(structure ('S', "S", 'H', "H")) && continue
(parameter == "inertia") && !(structure ("S", "H")) && continue
val = cudss_get(solver, parameter)
else
perm = Cint[i for i=n:-1:1]
Expand All @@ -121,7 +121,7 @@ function cudss_execution()
x_gpu = CuVector(x_cpu)
b_gpu = CuVector(b_cpu)

matrix = CudssMatrix(A_gpu, 'G', 'F')
matrix = CudssMatrix(A_gpu, "G", 'F')
config = CudssConfig()
data = CudssData()
solver = CudssSolver(matrix, config, data)
Expand All @@ -147,7 +147,7 @@ function cudss_execution()
X_gpu = CuMatrix(X_cpu)
B_gpu = CuMatrix(B_cpu)

structure = T <: Real ? 'S' : 'H'
structure = T <: Real ? "S" : "H"
matrix = CudssMatrix(A_gpu, structure, view)
config = CudssConfig()
data = CudssData()
Expand Down

0 comments on commit d36e766

Please sign in to comment.