Skip to content

Commit 782efcd

Browse files
Stuart DainesChrisRackauckas
Stuart Daines
authored andcommitted
Replace Base.convert(::Type{SparseArrays.SparseMatrixCSC}, J::SUNMatrix) with copyto!
Fix for #315 Proximate cause was that the SUNMatrix sparse Jacobian was not fully initialised, hence failed the SparseMatrixCSC consistency checks introduces with Julia 1.7 More fundamentally, given that Sundials uses zero-based index arrays, convert(::Type{SparseArrays.SparseMatrixCSC}, J::SUNMatrix) would either need to allocate, or is inherently unsafe as an in-place version that applied a 1-based index conversion would create an invalid SUNMatrix. Also the convert code here looks like it modified colptr but not rowptr, so couldn't have worked anyway... (cvodejac and similar used to overwrite this, which is why those calls worked).
1 parent f67f1a6 commit 782efcd

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

src/common_interface/function_types.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,14 @@ function cvodejac(
6565
tmp3::N_Vector,
6666
)
6767
jac_prototype = funjac.jac_prototype
68-
J = convert(SparseArrays.SparseMatrixCSC, _J)
69-
68+
7069
funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
7170
_u = funjac.u
7271

7372
funjac.jac(jac_prototype, _u, funjac.p, t)
74-
J.nzval .= jac_prototype.nzval
75-
# Sundials resets the value pointers each time, so reset it too
76-
@. J.rowval = jac_prototype.rowval - 1
77-
@. J.colptr = jac_prototype.colptr - 1
73+
74+
copyto!(_J, jac_prototype)
75+
7876
return CV_SUCCESS
7977
end
8078

@@ -128,7 +126,6 @@ function idajac(
128126
)
129127

130128
jac_prototype = funjac.jac_prototype
131-
J = convert(SparseArrays.SparseMatrixCSC, _J)
132129

133130
funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
134131
_u = funjac.u
@@ -137,10 +134,8 @@ function idajac(
137134
_du = funjac.du
138135

139136
funjac.jac(jac_prototype, _du, _u, funjac.p, cj, t)
140-
J.nzval .= jac_prototype.nzval
141-
# Sundials resets the value pointers each time, so reset it too
142-
@. J.rowval = jac_prototype.rowval - 1
143-
@. J.colptr = jac_prototype.colptr - 1
137+
138+
copyto!(_J, jac_prototype)
144139

145140
return IDA_SUCCESS
146141
end
@@ -155,10 +150,10 @@ function massmat(
155150
)
156151
if typeof(mmf.mass_matrix) <: Array
157152
M = convert(Matrix, _M)
153+
M .= mmf.mass_matrix
158154
else
159-
M = convert(SparseArrays.SparseMatrixCSC, _M)
155+
copyto!(_M, mmf.mass_matrix)
160156
end
161-
M .= mmf.mass_matrix
162157

163158
return IDA_SUCCESS
164159
end

src/types_and_consts_additions.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,21 @@ function Base.convert(::Type{Matrix}, J::SUNMatrix)
3939
unsafe_wrap(Array, mat.data, (mat.M, mat.N), own = false)
4040
end
4141

42-
function Base.convert(::Type{SparseArrays.SparseMatrixCSC}, J::SUNMatrix)
43-
_sunmat = unsafe_load(J)
42+
# sparse SUNMatrix uses zero-offset indices, so provide copyto!, not convert
43+
function Base.copyto!(Asun::SUNMatrix, Acsc::SparseArrays.SparseMatrixCSC)
44+
_sunmat = unsafe_load(Asun)
4445
_mat = convert(SUNMatrixContent_Sparse, _sunmat.content)
4546
mat = unsafe_load(_mat)
46-
# own is false as memory is allocated by sundials
47-
# TODO: Get rid of allocation for 1-based index change
48-
rowval = unsafe_wrap(Array, mat.indexvals, (mat.NNZ), own = false)
49-
colptr = unsafe_wrap(Array, mat.indexptrs, (mat.NP + 1), own = false)
50-
colptr .+= 1
51-
m = mat.M
52-
n = mat.N
53-
nzval = unsafe_wrap(Array, mat.data, (mat.NNZ), own = false)
54-
SparseArrays.SparseMatrixCSC(m, n, colptr, rowval, nzval)
47+
# own is false as memory is allocated by sundials
48+
indexvals = unsafe_wrap(Array, mat.indexvals, (mat.NNZ), own = false)
49+
indexptrs = unsafe_wrap(Array, mat.indexptrs, (mat.NP + 1), own = false)
50+
data = unsafe_wrap(Array, mat.data, (mat.NNZ), own = false)
51+
52+
@. indexvals = Acsc.rowval - 1
53+
@. indexptrs = Acsc.colptr - 1
54+
data .= Acsc.nzval
55+
56+
return nothing
5557
end
5658

5759
abstract type SundialsMatrix end

0 commit comments

Comments
 (0)