Skip to content

Commit ab1cd7c

Browse files
committed
wrap_gen: add wrappers supporting Julia objects
- add Handle{} (based on PR #29) that provides automatic memory management for CVODEMem, KINMem and IDAMem - add NVector wrapper (based on PR #29) that provides automatic memory management and conversion to/from Julia arrays for N_Vector - generate wrapper functions accepting Julia objects (Handles, NVectors, Vectors, Arrays); - the names of the "bare" wrappers that pass the arguments as-is are mangled by "__" (to avoid stack overflow if argument types are incompatible)
1 parent 891e391 commit ab1cd7c

12 files changed

+3601
-1421
lines changed

src/Sundials.jl

Lines changed: 11 additions & 337 deletions
Large diffs are not rendered by default.

src/cvode.jl

Lines changed: 414 additions & 138 deletions
Large diffs are not rendered by default.

src/cvodes.jl

Lines changed: 795 additions & 265 deletions
Large diffs are not rendered by default.

src/handle.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
##################################################################
2+
#
3+
# Pointers to Sundials objects
4+
#
5+
##################################################################
6+
7+
"""
8+
Base type for dummy placeholders that help to
9+
providing typed pointers for Sundials objects
10+
(KINSOL, CVODE, IDA).
11+
12+
See `Handle`.
13+
"""
14+
abstract AbstractSundialsObject
15+
16+
immutable CVODEMem <: AbstractSundialsObject end
17+
typealias CVODEMemPtr Ptr{CVODEMem}
18+
19+
immutable IDAMem <: AbstractSundialsObject end
20+
typealias IDAMemPtr Ptr{IDAMem}
21+
22+
immutable KINMem <: AbstractSundialsObject end
23+
typealias KINMemPtr Ptr{KINMem}
24+
25+
"""
26+
Handle for Sundials objects (CVODE, IDA, KIN).
27+
28+
Wraps the reference to the pointer to the Sundials object.
29+
Manages automatic destruction of the referenced objects when it is
30+
no longer in use.
31+
"""
32+
immutable Handle{T <: AbstractSundialsObject}
33+
ptr_ref::Ref{Ptr{T}} # pointer to a pointer
34+
35+
function Base.call{T}(::Type{Handle}, ptr::Ptr{T})
36+
h = new{T}(Ref{Ptr{T}}(ptr))
37+
finalizer(h.ptr_ref, release_handle)
38+
return h
39+
end
40+
end
41+
42+
Base.convert{T}(::Type{Ptr{T}}, h::Handle{T}) = h.ptr_ref[]
43+
Base.convert{T}(::Type{Ptr{Ptr{T}}}, h::Handle{T}) = convert(Ptr{Ptr{T}}, h.ptr_ref[])
44+
45+
release_handle{T}(ptr_ref::Ref{Ptr{T}}) = throw(MethodError("Freeing objects of type $T not supported"))
46+
release_handle(ptr_ref::Ref{Ptr{KINMem}}) = KINSOLFree(ptr_ref)
47+
release_handle(ptr_ref::Ref{Ptr{CVODEMem}}) = CVodeFree(ptr_ref)
48+
release_handle(ptr_ref::Ref{Ptr{IDAMem}}) = IDAFree(ptr_ref)
49+
50+
##################################################################
51+
#
52+
# Convenience typealiases for Sundials handles
53+
#
54+
##################################################################
55+
56+
typealias CVODEh Handle{CVODEMem}
57+
typealias KINh Handle{KINMem}
58+
typealias IDAh Handle{IDAMem}

src/ida.jl

Lines changed: 426 additions & 142 deletions
Large diffs are not rendered by default.

src/idas.jl

Lines changed: 816 additions & 272 deletions
Large diffs are not rendered by default.

src/kinsol.jl

Lines changed: 312 additions & 104 deletions
Large diffs are not rendered by default.

src/nvector.jl

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,118 +17,174 @@ function N_VMake_Serial(vec_length::Clong,v_data::Ptr{realtype})
1717
ccall((:N_VMake_Serial,libsundials_nvecserial),N_Vector,(Clong,Ptr{realtype}),vec_length,v_data)
1818
end
1919

20-
function N_VCloneVectorArray_Serial(count::Cint,w::N_Vector)
20+
function __N_VCloneVectorArray_Serial(count::Cint,w::N_Vector)
2121
ccall((:N_VCloneVectorArray_Serial,libsundials_nvecserial),Ptr{N_Vector},(Cint,N_Vector),count,w)
2222
end
2323

24-
function N_VCloneVectorArrayEmpty_Serial(count::Cint,w::N_Vector)
24+
N_VCloneVectorArray_Serial(count,w) = __N_VCloneVectorArray_Serial(count,convert(N_Vector,w))
25+
26+
function __N_VCloneVectorArrayEmpty_Serial(count::Cint,w::N_Vector)
2527
ccall((:N_VCloneVectorArrayEmpty_Serial,libsundials_nvecserial),Ptr{N_Vector},(Cint,N_Vector),count,w)
2628
end
2729

30+
N_VCloneVectorArrayEmpty_Serial(count,w) = __N_VCloneVectorArrayEmpty_Serial(count,convert(N_Vector,w))
31+
2832
function N_VDestroyVectorArray_Serial(vs::Ptr{N_Vector},count::Cint)
2933
ccall((:N_VDestroyVectorArray_Serial,libsundials_nvecserial),Void,(Ptr{N_Vector},Cint),vs,count)
3034
end
3135

32-
function N_VPrint_Serial(v::N_Vector)
36+
function __N_VPrint_Serial(v::N_Vector)
3337
ccall((:N_VPrint_Serial,libsundials_nvecserial),Void,(N_Vector,),v)
3438
end
3539

36-
function N_VCloneEmpty_Serial(w::N_Vector)
40+
N_VPrint_Serial(v) = __N_VPrint_Serial(convert(N_Vector,v))
41+
42+
function __N_VCloneEmpty_Serial(w::N_Vector)
3743
ccall((:N_VCloneEmpty_Serial,libsundials_nvecserial),N_Vector,(N_Vector,),w)
3844
end
3945

40-
function N_VClone_Serial(w::N_Vector)
46+
N_VCloneEmpty_Serial(w) = __N_VCloneEmpty_Serial(convert(N_Vector,w))
47+
48+
function __N_VClone_Serial(w::N_Vector)
4149
ccall((:N_VClone_Serial,libsundials_nvecserial),N_Vector,(N_Vector,),w)
4250
end
4351

44-
function N_VDestroy_Serial(v::N_Vector)
52+
N_VClone_Serial(w) = __N_VClone_Serial(convert(N_Vector,w))
53+
54+
function __N_VDestroy_Serial(v::N_Vector)
4555
ccall((:N_VDestroy_Serial,libsundials_nvecserial),Void,(N_Vector,),v)
4656
end
4757

48-
function N_VSpace_Serial(v::N_Vector,lrw::Ptr{Clong},liw::Ptr{Clong})
58+
N_VDestroy_Serial(v) = __N_VDestroy_Serial(convert(N_Vector,v))
59+
60+
function __N_VSpace_Serial(v::N_Vector,lrw::Ptr{Clong},liw::Ptr{Clong})
4961
ccall((:N_VSpace_Serial,libsundials_nvecserial),Void,(N_Vector,Ptr{Clong},Ptr{Clong}),v,lrw,liw)
5062
end
5163

52-
function N_VGetArrayPointer_Serial(v::N_Vector)
64+
N_VSpace_Serial(v,lrw,liw) = __N_VSpace_Serial(convert(N_Vector,v),pointer(lrw),pointer(liw))
65+
66+
function __N_VGetArrayPointer_Serial(v::N_Vector)
5367
ccall((:N_VGetArrayPointer_Serial,libsundials_nvecserial),Ptr{realtype},(N_Vector,),v)
5468
end
5569

56-
function N_VSetArrayPointer_Serial(v_data::Ptr{realtype},v::N_Vector)
70+
N_VGetArrayPointer_Serial(v) = __N_VGetArrayPointer_Serial(convert(N_Vector,v))
71+
72+
function __N_VSetArrayPointer_Serial(v_data::Ptr{realtype},v::N_Vector)
5773
ccall((:N_VSetArrayPointer_Serial,libsundials_nvecserial),Void,(Ptr{realtype},N_Vector),v_data,v)
5874
end
5975

60-
function N_VLinearSum_Serial(a::realtype,x::N_Vector,b::realtype,y::N_Vector,z::N_Vector)
76+
N_VSetArrayPointer_Serial(v_data,v) = __N_VSetArrayPointer_Serial(pointer(v_data),convert(N_Vector,v))
77+
78+
function __N_VLinearSum_Serial(a::realtype,x::N_Vector,b::realtype,y::N_Vector,z::N_Vector)
6179
ccall((:N_VLinearSum_Serial,libsundials_nvecserial),Void,(realtype,N_Vector,realtype,N_Vector,N_Vector),a,x,b,y,z)
6280
end
6381

64-
function N_VConst_Serial(c::realtype,z::N_Vector)
82+
N_VLinearSum_Serial(a,x,b,y,z) = __N_VLinearSum_Serial(a,convert(N_Vector,x),b,convert(N_Vector,y),convert(N_Vector,z))
83+
84+
function __N_VConst_Serial(c::realtype,z::N_Vector)
6585
ccall((:N_VConst_Serial,libsundials_nvecserial),Void,(realtype,N_Vector),c,z)
6686
end
6787

68-
function N_VProd_Serial(x::N_Vector,y::N_Vector,z::N_Vector)
88+
N_VConst_Serial(c,z) = __N_VConst_Serial(c,convert(N_Vector,z))
89+
90+
function __N_VProd_Serial(x::N_Vector,y::N_Vector,z::N_Vector)
6991
ccall((:N_VProd_Serial,libsundials_nvecserial),Void,(N_Vector,N_Vector,N_Vector),x,y,z)
7092
end
7193

72-
function N_VDiv_Serial(x::N_Vector,y::N_Vector,z::N_Vector)
94+
N_VProd_Serial(x,y,z) = __N_VProd_Serial(convert(N_Vector,x),convert(N_Vector,y),convert(N_Vector,z))
95+
96+
function __N_VDiv_Serial(x::N_Vector,y::N_Vector,z::N_Vector)
7397
ccall((:N_VDiv_Serial,libsundials_nvecserial),Void,(N_Vector,N_Vector,N_Vector),x,y,z)
7498
end
7599

76-
function N_VScale_Serial(c::realtype,x::N_Vector,z::N_Vector)
100+
N_VDiv_Serial(x,y,z) = __N_VDiv_Serial(convert(N_Vector,x),convert(N_Vector,y),convert(N_Vector,z))
101+
102+
function __N_VScale_Serial(c::realtype,x::N_Vector,z::N_Vector)
77103
ccall((:N_VScale_Serial,libsundials_nvecserial),Void,(realtype,N_Vector,N_Vector),c,x,z)
78104
end
79105

80-
function N_VAbs_Serial(x::N_Vector,z::N_Vector)
106+
N_VScale_Serial(c,x,z) = __N_VScale_Serial(c,convert(N_Vector,x),convert(N_Vector,z))
107+
108+
function __N_VAbs_Serial(x::N_Vector,z::N_Vector)
81109
ccall((:N_VAbs_Serial,libsundials_nvecserial),Void,(N_Vector,N_Vector),x,z)
82110
end
83111

84-
function N_VInv_Serial(x::N_Vector,z::N_Vector)
112+
N_VAbs_Serial(x,z) = __N_VAbs_Serial(convert(N_Vector,x),convert(N_Vector,z))
113+
114+
function __N_VInv_Serial(x::N_Vector,z::N_Vector)
85115
ccall((:N_VInv_Serial,libsundials_nvecserial),Void,(N_Vector,N_Vector),x,z)
86116
end
87117

88-
function N_VAddConst_Serial(x::N_Vector,b::realtype,z::N_Vector)
118+
N_VInv_Serial(x,z) = __N_VInv_Serial(convert(N_Vector,x),convert(N_Vector,z))
119+
120+
function __N_VAddConst_Serial(x::N_Vector,b::realtype,z::N_Vector)
89121
ccall((:N_VAddConst_Serial,libsundials_nvecserial),Void,(N_Vector,realtype,N_Vector),x,b,z)
90122
end
91123

92-
function N_VDotProd_Serial(x::N_Vector,y::N_Vector)
124+
N_VAddConst_Serial(x,b,z) = __N_VAddConst_Serial(convert(N_Vector,x),b,convert(N_Vector,z))
125+
126+
function __N_VDotProd_Serial(x::N_Vector,y::N_Vector)
93127
ccall((:N_VDotProd_Serial,libsundials_nvecserial),realtype,(N_Vector,N_Vector),x,y)
94128
end
95129

96-
function N_VMaxNorm_Serial(x::N_Vector)
130+
N_VDotProd_Serial(x,y) = __N_VDotProd_Serial(convert(N_Vector,x),convert(N_Vector,y))
131+
132+
function __N_VMaxNorm_Serial(x::N_Vector)
97133
ccall((:N_VMaxNorm_Serial,libsundials_nvecserial),realtype,(N_Vector,),x)
98134
end
99135

100-
function N_VWrmsNorm_Serial(x::N_Vector,w::N_Vector)
136+
N_VMaxNorm_Serial(x) = __N_VMaxNorm_Serial(convert(N_Vector,x))
137+
138+
function __N_VWrmsNorm_Serial(x::N_Vector,w::N_Vector)
101139
ccall((:N_VWrmsNorm_Serial,libsundials_nvecserial),realtype,(N_Vector,N_Vector),x,w)
102140
end
103141

104-
function N_VWrmsNormMask_Serial(x::N_Vector,w::N_Vector,id::N_Vector)
142+
N_VWrmsNorm_Serial(x,w) = __N_VWrmsNorm_Serial(convert(N_Vector,x),convert(N_Vector,w))
143+
144+
function __N_VWrmsNormMask_Serial(x::N_Vector,w::N_Vector,id::N_Vector)
105145
ccall((:N_VWrmsNormMask_Serial,libsundials_nvecserial),realtype,(N_Vector,N_Vector,N_Vector),x,w,id)
106146
end
107147

108-
function N_VMin_Serial(x::N_Vector)
148+
N_VWrmsNormMask_Serial(x,w,id) = __N_VWrmsNormMask_Serial(convert(N_Vector,x),convert(N_Vector,w),convert(N_Vector,id))
149+
150+
function __N_VMin_Serial(x::N_Vector)
109151
ccall((:N_VMin_Serial,libsundials_nvecserial),realtype,(N_Vector,),x)
110152
end
111153

112-
function N_VWL2Norm_Serial(x::N_Vector,w::N_Vector)
154+
N_VMin_Serial(x) = __N_VMin_Serial(convert(N_Vector,x))
155+
156+
function __N_VWL2Norm_Serial(x::N_Vector,w::N_Vector)
113157
ccall((:N_VWL2Norm_Serial,libsundials_nvecserial),realtype,(N_Vector,N_Vector),x,w)
114158
end
115159

116-
function N_VL1Norm_Serial(x::N_Vector)
160+
N_VWL2Norm_Serial(x,w) = __N_VWL2Norm_Serial(convert(N_Vector,x),convert(N_Vector,w))
161+
162+
function __N_VL1Norm_Serial(x::N_Vector)
117163
ccall((:N_VL1Norm_Serial,libsundials_nvecserial),realtype,(N_Vector,),x)
118164
end
119165

120-
function N_VCompare_Serial(c::realtype,x::N_Vector,z::N_Vector)
166+
N_VL1Norm_Serial(x) = __N_VL1Norm_Serial(convert(N_Vector,x))
167+
168+
function __N_VCompare_Serial(c::realtype,x::N_Vector,z::N_Vector)
121169
ccall((:N_VCompare_Serial,libsundials_nvecserial),Void,(realtype,N_Vector,N_Vector),c,x,z)
122170
end
123171

124-
function N_VInvTest_Serial(x::N_Vector,z::N_Vector)
172+
N_VCompare_Serial(c,x,z) = __N_VCompare_Serial(c,convert(N_Vector,x),convert(N_Vector,z))
173+
174+
function __N_VInvTest_Serial(x::N_Vector,z::N_Vector)
125175
ccall((:N_VInvTest_Serial,libsundials_nvecserial),Cint,(N_Vector,N_Vector),x,z)
126176
end
127177

128-
function N_VConstrMask_Serial(c::N_Vector,x::N_Vector,m::N_Vector)
178+
N_VInvTest_Serial(x,z) = __N_VInvTest_Serial(convert(N_Vector,x),convert(N_Vector,z))
179+
180+
function __N_VConstrMask_Serial(c::N_Vector,x::N_Vector,m::N_Vector)
129181
ccall((:N_VConstrMask_Serial,libsundials_nvecserial),Cint,(N_Vector,N_Vector,N_Vector),c,x,m)
130182
end
131183

132-
function N_VMinQuotient_Serial(num::N_Vector,denom::N_Vector)
184+
N_VConstrMask_Serial(c,x,m) = __N_VConstrMask_Serial(convert(N_Vector,c),convert(N_Vector,x),convert(N_Vector,m))
185+
186+
function __N_VMinQuotient_Serial(num::N_Vector,denom::N_Vector)
133187
ccall((:N_VMinQuotient_Serial,libsundials_nvecserial),realtype,(N_Vector,N_Vector),num,denom)
134188
end
189+
190+
N_VMinQuotient_Serial(num,denom) = __N_VMinQuotient_Serial(convert(N_Vector,num),convert(N_Vector,denom))

src/nvector_wrapper.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
Wrapper for Sundials `N_Vector` that
3+
uses Julia `Vector{realtype}` as the data container.
4+
5+
Implements `DenseVector` interface and
6+
manages automatic destruction of the referenced `N_Vector` when it is
7+
no longer in use.
8+
"""
9+
immutable NVector <: DenseVector{realtype}
10+
ref_nv::Ref{N_Vector} # reference to N_Vector
11+
v::Vector{realtype} # array that is referenced by N_Vector
12+
13+
function NVector(v::Vector{realtype})
14+
nv = new(Ref{N_Vector}(N_VMake_Serial(length(v), pointer(v))), v)
15+
finalizer(nv.ref_nv, release_handle)
16+
return nv
17+
end
18+
end
19+
20+
release_handle(ref_nv::Ref{N_Vector}) = N_VDestroy_Serial(ref_nv[])
21+
22+
Base.size(nv::NVector, d...) = size(nv.v, d...)
23+
Base.stride(nv::NVector, d::Integer) = stride(nv.v, d)
24+
25+
Base.getindex(nv::NVector, i::Real) = getindex(nv.v, i)
26+
Base.getindex(nv::NVector, i::AbstractArray) = getindex(nv.v, i)
27+
Base.getindex(nv::NVector, inds...) = getindex(nv.v, inds...)
28+
29+
Base.setindex!(nv::NVector, X, i::Real) = setindex!(nv.v, X, i)
30+
Base.setindex!(nv::NVector, X, i::AbstractArray) = setindex!(nv.v, X, i)
31+
Base.setindex!(nv::NVector, X, inds...) = setindex!(nv.v, X, inds...)
32+
33+
##################################################################
34+
#
35+
# Methods to convert between Julia Vectors and Sundials N_Vectors.
36+
#
37+
##################################################################
38+
39+
Base.convert(::Type{NVector}, v::Vector{realtype}) = NVector(v)
40+
Base.convert{T<:Real}(::Type{NVector}, v::Vector{T}) = NVector(copy!(similar(v, realtype), v))
41+
Base.convert(::Type{NVector}, nv::NVector) = nv
42+
Base.convert(::Type{N_Vector}, nv::NVector) = nv.ref_nv[]
43+
Base.convert(::Type{Vector{realtype}}, nv::NVector)= nv.v
44+
Base.convert(::Type{Vector}, nv::NVector)= nv.v
45+
46+
""" `N_Vector(v::Vector{T})`
47+
48+
Converts Julia `Vector` to `N_Vector`.
49+
50+
Implicitly creates `NVector` object that manages automatic
51+
destruction of `N_Vector` object when no longer in use.
52+
"""
53+
Base.convert(::Type{N_Vector}, v::Vector{realtype}) = N_Vector(NVector(v))
54+
Base.convert{T<:Real}(::Type{N_Vector}, v::Vector{T}) = N_Vector(NVector(v))
55+
56+
nvlength(x::N_Vector) = unsafe_load(unsafe_load(convert(Ptr{Ptr{Clong}}, x)))
57+
# asarray() creates an array pointing to N_Vector data, but does not take the ownership
58+
@inline asarray(x::N_Vector) = pointer_to_array(N_VGetArrayPointer_Serial(x), (nvlength(x),), false)
59+
@inline asarray(x::N_Vector, dims::Tuple) = pointer_to_array(N_VGetArrayPointer_Serial(x), dims, false)
60+
asarray(x::Vector{realtype}) = x
61+
asarray(x::Ptr{realtype}, dims::Tuple) = pointer_to_array(x, dims)
62+
@inline Base.convert(::Type{Vector{realtype}}, x::N_Vector) = asarray(x)
63+
@inline Base.convert(::Type{Vector}, x::N_Vector) = asarray(x)
64+
65+
nvector(x::Vector{realtype}) = NVector(x)
66+
#nvector(x::N_Vector) = x

0 commit comments

Comments
 (0)