From 214dbd62cd5c09f052d6007c335d5ebee7c6fd89 Mon Sep 17 00:00:00 2001 From: Julien Schueller Date: Fri, 17 May 2024 09:02:20 +0200 Subject: [PATCH] Fix build with sundials 7 (#94) --- CHANGELOG | 3 + examples/kinsol_ors.py | 7 ++- setup.py | 6 +- src/lib/sundials_callbacks.pxi | 14 +++-- src/lib/sundials_callbacks_kinsol.pxi | 69 ++++++++++++----------- src/lib/sundials_includes.pxd | 12 +++- src/solvers/dasp3.py | 4 +- src/solvers/kinsol.pyx | 10 +++- src/solvers/odepack.py | 2 +- src/solvers/radau5.py | 14 ++--- src/solvers/runge_kutta.py | 2 +- src/solvers/sundials.pyx | 79 +++++++++++++++++++++------ tests/solvers/test_odassl.py | 1 - 13 files changed, 146 insertions(+), 77 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index b5cdf680..5820950e 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,8 @@ --- CHANGELOG --- +--- Assimulo-3.5.1 --- + * Fixed build with sundials 7.x + --- Assimulo-3.5.0 --- * Changed "numpy.float" to equivalent "numpy.float64" due to DeprecationWarnings in numpy >= 1.20. * Improved examples with sparse jacobians by omitting the zeros in the jacobians. diff --git a/examples/kinsol_ors.py b/examples/kinsol_ors.py index 29f263ab..ae0e5ec8 100644 --- a/examples/kinsol_ors.py +++ b/examples/kinsol_ors.py @@ -18,8 +18,9 @@ import os import nose import numpy as np -import scipy as sp import scipy.sparse as sps +import scipy.sparse.linalg as spsl +import scipy.io as spi from assimulo.solvers import KINSOL from assimulo.problem import Algebraic_Problem import warnings @@ -36,7 +37,7 @@ def run_example(with_plots=True): Iterative Methods for Sparse Linear Systems. """ #Read the original matrix - A_original = sp.io.mmread(os.path.join(file_path,"kinsol_ors_matrix.mtx")) + A_original = spi.mmread(os.path.join(file_path,"kinsol_ors_matrix.mtx")) #Scale the original matrix A = sps.spdiags(1.0/A_original.diagonal(), 0, len(A_original.diagonal()), len(A_original.diagonal())) * A_original @@ -51,7 +52,7 @@ def run_example(with_plots=True): U = D-F Prec = L.dot(U) - solvePrec = sps.linalg.factorized(Prec) + solvePrec = spsl.factorized(Prec) #Create the RHS b = A.dot(np.ones(A.shape[0])) diff --git a/setup.py b/setup.py index ffb58a62..e0648791 100644 --- a/setup.py +++ b/setup.py @@ -504,6 +504,8 @@ def cython_extensionlists(self): if self.SUNDIALS_version >= (3,0,0): ext_list[-1].libraries = ["sundials_cvodes", "sundials_nvecserial", "sundials_idas", "sundials_sunlinsoldense", "sundials_sunlinsolspgmr", "sundials_sunmatrixdense", "sundials_sunmatrixsparse"] + if self.SUNDIALS_version >= (7,0,0): + ext_list[-1].libraries.extend(["sundials_core"]) else: ext_list[-1].libraries = ["sundials_cvodes", "sundials_nvecserial", "sundials_idas"] if self.sundials_with_superlu and self.with_SLU: #If SUNDIALS is compiled with support for SuperLU @@ -523,6 +525,8 @@ def cython_extensionlists(self): ext_list[-1].include_dirs = [np.get_include(), "assimulo","assimulo"+os.sep+"lib", self.incdirs] ext_list[-1].library_dirs = [self.libdirs] ext_list[-1].libraries = ["sundials_kinsol", "sundials_nvecserial"] + if self.SUNDIALS_version >= (7,0,0): + ext_list[-1].libraries.extend(["sundials_core"]) if self.sundials_with_superlu and self.with_SLU: #If SUNDIALS is compiled with support for SuperLU ext_list[-1].include_dirs.append(self.SLUincdir) @@ -665,7 +669,7 @@ def fortran_extensionlists(self): NAME = "Assimulo" AUTHOR = u"C. Winther (Andersson), C. Führer, J. Åkesson, M. Gäfvert" AUTHOR_EMAIL = "christian.winther@modelon.com" -VERSION = "trunk" if version_number_arg == "Default" else version_number_arg +VERSION = "3.5.0-dev" if version_number_arg == "Default" else version_number_arg LICENSE = "LGPL" URL = "http://www.jmodelica.org/assimulo" DOWNLOAD_URL = "http://www.jmodelica.org/assimulo" diff --git a/src/lib/sundials_callbacks.pxi b/src/lib/sundials_callbacks.pxi index 1ea78894..e970aea8 100644 --- a/src/lib/sundials_callbacks.pxi +++ b/src/lib/sundials_callbacks.pxi @@ -24,7 +24,10 @@ from numpy cimport PyArray_DATA cdef N_Vector N_VNewEmpty_Euclidean(long int n) noexcept: IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void * comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = 0 + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector v = N_VNew_Serial(n, ctx) ELSE: @@ -39,7 +42,10 @@ cdef inline N_Vector arr2nv(x) noexcept: cdef void* data_ptr=PyArray_DATA(ndx) IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void * comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = 0 + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector v = N_VNew_Serial(n, ctx) ELSE: @@ -63,7 +69,7 @@ cdef inline void arr2nv_inplace(x, N_Vector out) noexcept: cdef void* data_ptr=PyArray_DATA(ndx) memcpy((out.content).data, data_ptr, n*sizeof(realtype)) -cdef inline np.ndarray nv2arr(N_Vector v) noexcept: +cdef inline np.ndarray nv2arr(N_Vector v): cdef long int n = (v.content).length cdef realtype* v_data = (v.content).data cdef np.ndarray[realtype, ndim=1, mode='c'] x=np.empty(n) @@ -82,7 +88,7 @@ cdef inline void nv2mat_inplace(int Ns, N_Vector *v, np.ndarray o) noexcept: for j in range(Nf): o[j,i] = (v[i].content).data[j] -cdef inline realtype2arr(realtype *data, int n) noexcept: +cdef inline realtype2arr(realtype *data, int n): """Create new numpy array from realtype*""" cdef np.ndarray[realtype, ndim=1, mode='c'] x=np.empty(n) memcpy(PyArray_DATA(x), data, n*sizeof(realtype)) diff --git a/src/lib/sundials_callbacks_kinsol.pxi b/src/lib/sundials_callbacks_kinsol.pxi index 3e637e98..3066c293 100644 --- a/src/lib/sundials_callbacks_kinsol.pxi +++ b/src/lib/sundials_callbacks_kinsol.pxi @@ -256,41 +256,40 @@ ELSE: #print(""%(fnorm, snorm, pData.TOL)) -cdef void kin_info(const char *module, const char *function, char *msg, void *eh_data) noexcept: - cdef ProblemDataEquationSolver pData = eh_data - cdef int flag - cdef realtype fnorm - - if str(function) == "KINSol" and "fnorm" in str(msg): - #fnorm = float(msg.split("fnorm = ")[-1].strip()) - flag = SUNDIALS.KINGetFuncNorm(pData.KIN_MEM, &fnorm) - pData.nl_fnorm.append(fnorm) - - pData.log.append([module, function, msg]) - - #print("KinsolInfo "%function) - #print(""%msg) - """ - # Get the number of iterations - KINGetNumNonlinSolvIters(kin_mem, &nniters) - - - /* Only output an iteration under certain conditions: - * 1. nle_solver_log > 2 - * 2. The calling function is either KINSolInit or KINSol - * 3. The message string starts with "nni" - * - * This approach gives one printout per iteration - - - if ("KINSolInit" in function or "KINSol" in function) and "nni" in msg: - print(""%nniters) - print("ivs", N_VGetArrayPointer(kin_mem->kin_uu), block->n)) - print("", kin_mem->kin_fnorm)) - print("residuals", - realtype* f = N_VGetArrayPointer(kin_mem->kin_fval); - f[i]*residual_scaling_factors[i]) - """ + + cdef void kin_info(const char *module, const char *function, char *msg, void *eh_data) noexcept: + cdef ProblemDataEquationSolver pData = eh_data + cdef int flag + cdef realtype fnorm + + if str(function) == "KINSol" and "fnorm" in str(msg): + #fnorm = float(msg.split("fnorm = ")[-1].strip()) + flag = SUNDIALS.KINGetFuncNorm(pData.KIN_MEM, &fnorm) + pData.nl_fnorm.append(fnorm) + + pData.log.append([module, function, msg]) + + #print("KinsolInfo "%function) + #print(""%msg) + """ + # Get the number of iterations + KINGetNumNonlinSolvIters(kin_mem, &nniters) + + /* Only output an iteration under certain conditions: + * 1. nle_solver_log > 2 + * 2. The calling function is either KINSolInit or KINSol + * 3. The message string starts with "nni" + * + * This approach gives one printout per iteration + + if ("KINSolInit" in function or "KINSol" in function) and "nni" in msg: + print(""%nniters) + print("ivs", N_VGetArrayPointer(kin_mem->kin_uu), block->n)) + print("", kin_mem->kin_fnorm)) + print("residuals", + realtype* f = N_VGetArrayPointer(kin_mem->kin_fval); + f[i]*residual_scaling_factors[i]) + """ cdef class ProblemDataEquationSolver: cdef: diff --git a/src/lib/sundials_includes.pxd b/src/lib/sundials_includes.pxd index 0480bf0d..671bc3db 100644 --- a/src/lib/sundials_includes.pxd +++ b/src/lib/sundials_includes.pxd @@ -36,7 +36,11 @@ IF SUNDIALS_VERSION >= (6,0,0): ctypedef _SUNContext * SUNContext cdef struct _SUNContext: pass - int SUNContext_Create(void* comm, SUNContext* ctx) noexcept + IF SUNDIALS_VERSION >= (7,0,0): + ctypedef int SUNComm + int SUNContext_Create(SUNComm comm, SUNContext* ctx) noexcept + ELSE: + int SUNContext_Create(void* comm, SUNContext* ctx) noexcept IF SUNDIALS_VERSION >= (7,0,0): cdef extern from "sundials/sundials_context.h": @@ -48,6 +52,8 @@ IF SUNDIALS_VERSION >= (6,0,0): cdef extern from "sundials/sundials_types.h": ctypedef double sunrealtype ctypedef bint sunbooleantype + IF SUNDIALS_VERSION >= (7,0,0): + cdef int SUN_COMM_NULL ctypedef double realtype ctypedef bint booleantype ELSE: @@ -494,7 +500,7 @@ ELSE: N_Vector tmp2, N_Vector tmp3) noexcept int CVSlsSetSparseJacFn(void *cvode_mem, CVSlsSparseJacFn jac) noexcept int CVSlsGetNumJacEvals(void *cvode_mem, long int *njevals) noexcept - cdef inline tuple version() noexcept: return (2,6,0) + cdef inline tuple version(): return (2,6,0) IF SUNDIALS_WITH_SUPERLU: cdef extern from "cvodes/cvodes_superlumt.h": int CVSuperLUMT(void *cvode_mem, int numthreads, int n, int nnz) noexcept @@ -579,7 +585,7 @@ cdef extern from "idas/idas.h": int IDAGetNumResEvals(void *ida_mem, long int *nrevals) #Number of res evals IF SUNDIALS_VERSION >= (4,0,0): int IDAGetNumJacEvals(void *ida_mem, long int *njevals) #Number of jac evals - int IDAGetNumResEvals(void *ida_mem, long int *nrevalsLS) #Number of res evals due to jac evals + int IDAGetNumLinResEvals(void *ida_mem, long int *nrevalsLS) #Number of res evals due to jac evals ELSE: int IDADlsGetNumJacEvals(void *ida_mem, long int *njevals) #Number of jac evals int IDADlsGetNumResEvals(void *ida_mem, long int *nrevalsLS) #Number of res evals due to jac evals diff --git a/src/solvers/dasp3.py b/src/solvers/dasp3.py index 3318aebf..504cbe0a 100644 --- a/src/solvers/dasp3.py +++ b/src/solvers/dasp3.py @@ -44,8 +44,8 @@ class DASP3ODE(Explicit_ODE): .. math:: - \\frac{\mathrm{d}y}{\mathrm{d}t} &= f(t,y,z) \;\;\; \\text{(N equations)} \\\\ - \\varepsilon\\frac{\mathrm{d}z}{\mathrm{d}t} &= G(t,y,z)\;\;\; \\text{(M equations)} + \\frac{\\mathrm{d}y}{\\mathrm{d}t} &= f(t,y,z) \\;\\;\\; \\text{(N equations)} \\\\ + \\varepsilon\\frac{\\mathrm{d}z}{\\mathrm{d}t} &= G(t,y,z)\\;\\;\\; \\text{(M equations)} If is assumed that the first system is non-stiff and that the stiffness of the second system is due to the parameter diff --git a/src/solvers/kinsol.pyx b/src/solvers/kinsol.pyx index de9bd9e3..341f75a1 100644 --- a/src/solvers/kinsol.pyx +++ b/src/solvers/kinsol.pyx @@ -183,7 +183,10 @@ cdef class KINSOL(Algebraic): cdef int flag #Used for return IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void * comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) self.y_temp = arr2nv(self.y) @@ -230,7 +233,10 @@ cdef class KINSOL(Algebraic): cpdef add_linear_solver(self): IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void * comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) if self.options["linear_solver"] == "DENSE": IF SUNDIALS_VERSION >= (3,0,0): diff --git a/src/solvers/odepack.py b/src/solvers/odepack.py index e3943c32..3c7bc8e4 100644 --- a/src/solvers/odepack.py +++ b/src/solvers/odepack.py @@ -48,7 +48,7 @@ class LSODAR(Explicit_ODE): .. math:: - \dot{y} = f(t,y), \quad y(t_0) = y_0. + \\dot{y} = f(t,y), \\quad y(t_0) = y_0. LSODAR automatically switches between using an ADAMS method or an BDF method and is also able to monitor events. diff --git a/src/solvers/radau5.py b/src/solvers/radau5.py index 16efa4ce..c92bcb31 100644 --- a/src/solvers/radau5.py +++ b/src/solvers/radau5.py @@ -16,7 +16,7 @@ # along with this program. If not, see . import numpy as np -import scipy as sp +import scipy.linalg as spl import scipy.sparse as sps from assimulo.exception import ( @@ -707,9 +707,9 @@ def newton(self,t,y): self._g = self._gamma/self.h self._B = self._g*self.I - self._jac - self._P1,self._L1,self._U1 = sp.linalg.lu(self._B) #LU decomposition - self._P2,self._L2,self._U2 = sp.linalg.lu(self._a*self.I-self._jac) - self._P3,self._L3,self._U3 = sp.linalg.lu(self._b*self.I-self._jac) + self._P1,self._L1,self._U1 = spl.lu(self._B) #LU decomposition + self._P2,self._L2,self._U2 = spl.lu(self._a*self.I-self._jac) + self._P3,self._L3,self._U3 = spl.lu(self._b*self.I-self._jac) self._needLU = False @@ -1526,9 +1526,9 @@ def newton(self,t,y,yd): self._g = self._gamma/self.h self._B = self._g*self.M - self._jac - self._P1,self._L1,self._U1 = sp.linalg.lu(self._B) #LU decomposition - self._P2,self._L2,self._U2 = sp.linalg.lu(self._a*self.M-self._jac) - self._P3,self._L3,self._U3 = sp.linalg.lu(self._b*self.M-self._jac) + self._P1,self._L1,self._U1 = spl.lu(self._B) #LU decomposition + self._P2,self._L2,self._U2 = spl.lu(self._a*self.M-self._jac) + self._P3,self._L3,self._U3 = spl.lu(self._b*self.M-self._jac) self._needLU = False diff --git a/src/solvers/runge_kutta.py b/src/solvers/runge_kutta.py index 913f8c15..98012718 100644 --- a/src/solvers/runge_kutta.py +++ b/src/solvers/runge_kutta.py @@ -753,7 +753,7 @@ class RungeKutta4(Explicit_ODE): .. math:: - \dot{y} = f(t,y), \quad y(t_0) = y_0 . + \\dot{y} = f(t,y), \\quad y(t_0) = y_0 . Using a Runge-Kutta method of order 4, the approximation is defined as follow, diff --git a/src/solvers/sundials.pyx b/src/solvers/sundials.pyx index 84cc9b6e..a019f539 100644 --- a/src/solvers/sundials.pyx +++ b/src/solvers/sundials.pyx @@ -239,7 +239,10 @@ cdef class IDA(Implicit_ODE): cdef realtype ZERO = 0.0 IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void * comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) self.yTemp = arr2nv(self.y) @@ -739,7 +742,10 @@ cdef class IDA(Implicit_ODE): cdef np.ndarray err, pyweight, pyele IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector ele = N_VNew_Serial(self.pData.dim, ctx) cdef N_Vector eweight = N_VNew_Serial(self.pData.dim, ctx) @@ -774,7 +780,10 @@ cdef class IDA(Implicit_ODE): cdef np.ndarray res IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector dky=N_VNew_Serial(self.pData.dim, ctx) ELSE: @@ -815,7 +824,10 @@ cdef class IDA(Implicit_ODE): """ IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector dkyS=N_VNew_Serial(self.pData.dim, ctx) ELSE: @@ -1429,13 +1441,11 @@ cdef class IDA(Implicit_ODE): &klast, &kcur, &hinused, &hlast, &hcur, &tcur) flag = SUNDIALS.IDAGetNonlinSolvStats(self.ida_mem, &nniters, &nncfails) flag = SUNDIALS.IDAGetNumGEvals(self.ida_mem, &ngevals) - #flag = SUNDIALS.IDADlsGetNumJacEvals(self.ida_mem, &njevals) - #flag = SUNDIALS.IDADlsGetNumResEvals(self.ida_mem, &nrevalsLS) - + if self.options["linear_solver"] == "SPGMR": IF SUNDIALS_VERSION >= (4,0,0): flag = SUNDIALS.IDAGetNumJtimesEvals(self.ida_mem, &njvevals) #Number of jac*vector - flag = SUNDIALS.IDAGetNumResEvals(self.ida_mem, &nfevalsLS) #Number of rhs due to jac*vector + flag = SUNDIALS.IDAGetNumLinResEvals(self.ida_mem, &nfevalsLS) #Number of rhs due to jac*vector ELSE: flag = SUNDIALS.IDASpilsGetNumJtimesEvals(self.ida_mem, &njvevals) #Number of jac*vector flag = SUNDIALS.IDASpilsGetNumResEvals(self.ida_mem, &nfevalsLS) #Number of rhs due to jac*vector @@ -1444,7 +1454,7 @@ cdef class IDA(Implicit_ODE): else: IF SUNDIALS_VERSION >= (4,0,0): flag = SUNDIALS.IDAGetNumJacEvals(self.ida_mem, &njevals) - flag = SUNDIALS.IDAGetNumResEvals(self.ida_mem, &nrevalsLS) + flag = SUNDIALS.IDAGetNumLinResEvals(self.ida_mem, &nrevalsLS) ELSE: flag = SUNDIALS.IDADlsGetNumJacEvals(self.ida_mem, &njevals) flag = SUNDIALS.IDADlsGetNumResEvals(self.ida_mem, &nrevalsLS) @@ -1612,9 +1622,16 @@ cdef class CVode(Explicit_ODE): Returns the vector of estimated local errors at the current step. """ cdef int flag + + if self.cvode_mem == NULL: + raise CVodeError(CV_MEM_FAIL) + IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector ele=N_VNew_Serial(self.pData.dim, ctx) ELSE: @@ -1644,6 +1661,9 @@ cdef class CVode(Explicit_ODE): cdef int flag cdef int qlast + if self.cvode_mem == NULL: + raise CVodeError(CV_MEM_FAIL) + flag = SUNDIALS.CVodeGetLastOrder(self.cvode_mem, &qlast) if flag < 0: raise CVodeError(flag, self.t) @@ -1682,7 +1702,10 @@ cdef class CVode(Explicit_ODE): """ cdef int flag cdef int qcur - + + if self.cvode_mem == NULL: + raise CVodeError(CV_MEM_FAIL) + flag = SUNDIALS.CVodeGetCurrentOrder(self.cvode_mem, &qcur) if flag < 0: raise CVodeError(flag, self.t) @@ -1694,9 +1717,16 @@ cdef class CVode(Explicit_ODE): Returns the solution error weights at the current step. """ cdef int flag + + if self.cvode_mem == NULL: + raise CVodeError(CV_MEM_FAIL) + IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector eweight=N_VNew_Serial(self.pData.dim, ctx) ELSE: @@ -1770,7 +1800,10 @@ cdef class CVode(Explicit_ODE): cdef realtype ZERO = 0.0 IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void * comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) if self.options["norm"] == "EUCLIDEAN": @@ -1930,7 +1963,10 @@ cdef class CVode(Explicit_ODE): cdef np.ndarray res IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector dky=N_VNew_Serial(self.pData.dim, ctx) ELSE: @@ -1972,7 +2008,10 @@ cdef class CVode(Explicit_ODE): """ IF SUNDIALS_VERSION >= (6,0,0): cdef SUNDIALS.SUNContext ctx = NULL - cdef void* comm = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) cdef N_Vector dkyS=N_VNew_Serial(self.pData.dim, ctx) ELSE: @@ -2251,8 +2290,11 @@ cdef class CVode(Explicit_ODE): """ cdef flag IF SUNDIALS_VERSION >= (6,0,0): - cdef SUNDIALS.SUNContext ctx - cdef void* comm = NULL + cdef SUNDIALS.SUNContext ctx = NULL + IF SUNDIALS_VERSION >= (7,0,0): + cdef SUNDIALS.SUNComm comm = SUNDIALS.SUN_COMM_NULL + ELSE: + cdef void* comm = NULL SUNDIALS.SUNContext_Create(comm, &ctx) #Choose a linear solver if and only if NEWTON is choosen @@ -3229,6 +3271,9 @@ cdef class CVode(Explicit_ODE): cdef int qlast = 0, qcur = 0 cdef realtype hinused = 0.0, hlast = 0.0, hcur = 0.0, tcur = 0.0 + if self.cvode_mem == NULL: + raise CVodeError(CV_MEM_FAIL) + if self.options["linear_solver"] == "SPGMR": IF SUNDIALS_VERSION >= (4,0,0): flag = SUNDIALS.CVodeGetNumJtimesEvals(self.cvode_mem, &njvevals) #Number of jac*vector diff --git a/tests/solvers/test_odassl.py b/tests/solvers/test_odassl.py index 973ac05d..05452822 100644 --- a/tests/solvers/test_odassl.py +++ b/tests/solvers/test_odassl.py @@ -15,7 +15,6 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . -import nose from assimulo import testattr from assimulo.solvers.odassl import ODASSL from assimulo.problem import Implicit_Problem