Skip to content

Commit

Permalink
sundials
Browse files Browse the repository at this point in the history
  • Loading branch information
jschueller committed May 17, 2024
1 parent c915fe4 commit dfe4211
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 201 deletions.
50 changes: 46 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,25 @@ endif ()

if (USE_SUNDIALS)
find_package (SUNDIALS CONFIG)
if (NOT SUNDIALS_FOUND)
# fallback to our module
find_package (SUNDIALS MODULE)
set(SUNDIALS_DIR ${SUNDIALS_INCLUDE_DIR})
endif ()

if (SUNDIALS_FOUND)
message(STATUS "Found SUNDIALS: ${SUNDIALS_DIR} (found version \"${SUNDIALS_VERSION}\")")

string (REGEX REPLACE "([0-9]+)\\..*" "\\1" SUNDIALS_MAJOR_VERSION "${SUNDIALS_VERSION}")
string (REGEX REPLACE "[0-9]+\\.([0-9]+).*" "\\1" SUNDIALS_MINOR_VERSION "${SUNDIALS_VERSION}")
string (REGEX REPLACE "[0-9]+\\.[0-9]+\\.([0-9]+).*" "\\1" SUNDIALS_PATCH_VERSION "${SUNDIALS_VERSION}")
math(EXPR SUNDIALS_VERSION_NR "100000 * ${SUNDIALS_MAJOR_VERSION} + 100 * ${SUNDIALS_MINOR_VERSION} + ${SUNDIALS_PATCH_VERSION}")

if (SuperLU_MT_FOUND)
set (SUNDIALS_WITH_SUPERLU True)
else ()
set (SUNDIALS_WITH_SUPERLU False)
endif ()
endif ()
endif ()

Expand Down Expand Up @@ -82,7 +99,7 @@ macro(assimulo_add_cython_module pyx_file)
add_custom_command(
OUTPUT ${name}.c
COMMENT "Making ${name}.c from ${name}.pyx"
COMMAND Python::Interpreter -m cython -o ${name}.c --3str --fast-fail
COMMAND Python::Interpreter -m cython -o ${name}.c --3str --fast-fail ${CYTHON_FLAGS}
-I ${CMAKE_CURRENT_SOURCE_DIR}/src
-I ${CMAKE_CURRENT_SOURCE_DIR}/src/lib
${CMAKE_CURRENT_BINARY_DIR}/${pyx_file}
Expand All @@ -108,9 +125,34 @@ assimulo_add_cython_module(assimulo/special_systems.pyx)
assimulo_add_cython_module(assimulo/support.pyx)
assimulo_add_cython_module(assimulo/solvers/euler.pyx
DESTINATION assimulo/solvers)
#if (SUNDIALS_FOUND)
# assimulo_add_cython_module(solvers/sundials.pyx)
#endif ()

if (SUNDIALS_FOUND)
set(CYTHON_FLAGS -E SUNDIALS_VERSION_NR=${SUNDIALS_VERSION_NR} -E SUNDIALS_VECTOR_SIZE=64 -E SUNDIALS_WITH_SUPERLU=${SUNDIALS_WITH_SUPERLU} -E SUNDIALS_CVODE_RTOL_VEC=False)

set (SUNDIALS_LIBRARIES sundials_cvodes sundials_nvecserial sundials_idas)
if (SUNDIALS_VERSION VERSION_GREATER_EQUAL 3)
list (APPEND SUNDIALS_LIBRARIES sundials_sunlinsoldense sundials_sunlinsolspgmr sundials_sunmatrixdense sundials_sunmatrixsparse)
endif ()
if (SUNDIALS_VERSION VERSION_GREATER_EQUAL 7)
list (APPEND SUNDIALS_LIBRARIES sundials_core)
endif ()
if (SUNDIALS_WITH_SUPERLU)
list (APPEND SUNDIALS_LIBRARIES sundials_sunlinsolsuperlumt)
endif ()

assimulo_add_cython_module(assimulo/solvers/sundials.pyx
DESTINATION assimulo/solvers)
target_link_libraries(sundials PRIVATE ${SUNDIALS_LIBRARIES})

set (SUNDIALS_LIBRARIES sundials_kinsol sundials_nvecserial)
if (SUNDIALS_VERSION VERSION_GREATER_EQUAL 7)
list (APPEND SUNDIALS_LIBRARIES sundials_core)
endif ()

assimulo_add_cython_module(assimulo/solvers/kinsol.pyx
DESTINATION assimulo/solvers)
target_link_libraries(kinsol PRIVATE ${SUNDIALS_LIBRARIES})
endif ()


assimulo_add_cython_module(thirdparty/radau5/radau5ode.pyx
Expand Down
47 changes: 47 additions & 0 deletions cmake/FindSUNDIALS.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# - Find SUNDIALS
# SUNDIALS, the SUite of Nonlinear and DIfferential/ALgebraic equation Solvers.
# https://computing.llnl.gov/projects/sundials
#
# The module defines the following variables:
# SUNDIALS_VERSION, the version string
# SUNDIALS_INCLUDE_DIRS, where to find sundials_dense.h, etc.
# SUNDIALS_LIBRARIES, the libraries needed to use SUNDIALS
# SUNDIALS_FOUND, If false, do not try to use SUNDIALS
# also defined, but not for general use are
#

find_path (SUNDIALS_INCLUDE_DIR sundials_config.h PATH_SUFFIXES sundials)

file (STRINGS ${SUNDIALS_INCLUDE_DIR}/sundials_config.h _VERSION_DEFINE_STRING REGEX "#define SUNDIALS_VERSION .*")
if (_VERSION_DEFINE_STRING)
string (REGEX REPLACE "#define SUNDIALS_VERSION \"([0-9\.]+)\"" "\\1" SUNDIALS_VERSION ${_VERSION_DEFINE_STRING})
endif ()

set(SUNDIALS_LIBRARIES)
set(SUNDIALS_COMPONENTS sunlinsoldense sunlinsolspgmr sunmatrixdense sunmatrixsparse core sunlinsolsuperlumt kinsol nvecserial cvode cvodes)
foreach (COMPONENT ${SUNDIALS_COMPONENTS})
string(TOUPPER "${COMPONENT}" COMPONENT_UPPER)
find_library (SUNDIALS_${COMPONENT_UPPER}_LIBRARY NAMES sundials_${COMPONENT})

if (SUNDIALS_${COMPONENT_UPPER}_LIBRARY)
list (APPEND SUNDIALS_LIBRARIES ${SUNDIALS_${COMPONENT_UPPER}_LIBRARY})
endif ()
endforeach ()

set (SUNDIALS_INCLUDE_DIRS ${SUNDIALS_INCLUDE_DIR})

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(SUNDIALS DEFAULT_MSG SUNDIALS_INCLUDE_DIRS SUNDIALS_CVODES_LIBRARY VERSION_VAR SUNDIALS_VERSION)

mark_as_advanced (
SUNDIALS_LIBRARIES
SUNDIALS_INCLUDE_DIR
SUNDIALS_INCLUDE_DIRS)

if(NOT TARGET SUNDIALS::ALL)
add_library(SUNDIALS::ALL UNKNOWN IMPORTED)
set_target_properties(SUNDIALS::ALL
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${SUNDIALS_INCLUDE_DIRS}")
target_link_libraries(SUNDIALS::ALL INTERFACE ${SUNDIALS_LIBRARIES})
endif()
6 changes: 3 additions & 3 deletions cmake/FindSuperLU_MT.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# https://github.com/xiaoyeli/superlu_mt
#
# The module defines the following variables:
# SUPERLUMT_INCLUDE_DIRS, where to find mpc.h, etc.
# SUPERLUMT_LIBRARIES, the libraries needed to use MPC.
# SUPERLUMT_FOUND, If false, do not try to use MPC.
# SUPERLUMT_INCLUDE_DIRS, where to find superlu.h, etc.
# SUPERLUMT_LIBRARIES, the libraries needed to use SuperLU_MT
# SUPERLUMT_FOUND, If false, do not try to use SuperLU_MT
# also defined, but not for general use are
# SUPERLUMT_LIBRARY, where to find the MPC library.
#
Expand Down
24 changes: 12 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def check_SUNDIALS(self):
if os.path.exists(os.path.join(os.path.join(self.incdirs,'cvodes'), 'cvodes.h')):
self.with_SUNDIALS=True
logging.debug('SUNDIALS found.')
sundials_version = None
sundials_version_tuple = None
sundials_vector_type_size = None
sundials_with_superlu = False
sundials_with_msvc = False
Expand All @@ -386,8 +386,8 @@ def check_SUNDIALS(self):
with open(os.path.join(os.path.join(self.incdirs,'sundials'), 'sundials_config.h')) as f:
for line in f:
if "SUNDIALS_PACKAGE_VERSION" in line or "SUNDIALS_VERSION" in line:
sundials_version = tuple([int(f) for f in line.split()[-1][1:-1].split('-dev')[0].split(".")])
logging.debug('SUNDIALS %d.%d found.'%(sundials_version[0], sundials_version[1]))
sundials_version_tuple = tuple([int(f) for f in line.split()[-1][1:-1].split('-dev')[0].split(".")])
logging.debug('SUNDIALS %d.%d found.'%(sundials_version_tuple[0], sundials_version_tuple[1]))
break
with open(os.path.join(os.path.join(self.incdirs,'sundials'), 'sundials_config.h')) as f:
for line in f:
Expand Down Expand Up @@ -418,13 +418,13 @@ def check_SUNDIALS(self):
sundials_with_msvc = True
except Exception:
if os.path.exists(os.path.join(os.path.join(self.incdirs,'arkode'), 'arkode.h')): #This was added in 2.6
sundials_version = (2,6,0)
sundials_version_tuple = (2,6,0)
logging.debug('SUNDIALS 2.6 found.')
else:
sundials_version = (2,5,0)
sundials_version_tuple = (2,5,0)
logging.debug('SUNDIALS 2.5 found.')
self.SUNDIALS_version = sundials_version

self.SUNDIALS_version_nr = 100000 * sundials_version_tuple[0] + 100 * sundials_version_tuple[1] + sundials_version_tuple[2]
self.SUNDIALS_vector_size = sundials_vector_type_size
self.sundials_with_superlu = sundials_with_superlu
self.sundials_with_msvc = sundials_with_msvc
Expand Down Expand Up @@ -489,7 +489,7 @@ def cython_extensionlists(self):

# SUNDIALS
if self.with_SUNDIALS:
compile_time_env = {'SUNDIALS_VERSION': self.SUNDIALS_version,
compile_time_env = {'SUNDIALS_VERSION_NR': self.SUNDIALS_version_nr,
'SUNDIALS_WITH_SUPERLU': self.sundials_with_superlu and self.with_SLU,
'SUNDIALS_VECTOR_SIZE': self.SUNDIALS_vector_size,
'SUNDIALS_CVODE_RTOL_VEC': self.sundials_cvode_with_rtol_vec}
Expand All @@ -502,14 +502,14 @@ def cython_extensionlists(self):
ext_list[-1].include_dirs = [np.get_include(), "assimulo","assimulo"+os.sep+"lib", self.incdirs]
ext_list[-1].library_dirs = [self.libdirs]

if self.SUNDIALS_version >= (3,0,0):
if self.SUNDIALS_version_nr >= 300000:
ext_list[-1].libraries = ["sundials_cvodes", "sundials_nvecserial", "sundials_idas", "sundials_sunlinsoldense", "sundials_sunlinsolspgmr", "sundials_sunmatrixdense", "sundials_sunmatrixsparse"]
if self.SUNDIALS_version >= (7,0,0):
if self.SUNDIALS_version_nr >= 700000:
ext_list[-1].libraries.extend(["sundials_core"])
else:
ext_list[-1].libraries = ["sundials_cvodes", "sundials_nvecserial", "sundials_idas"]
if self.sundials_with_superlu and self.with_SLU: #If SUNDIALS is compiled with support for SuperLU
if self.SUNDIALS_version >= (3,0,0):
if self.SUNDIALS_version_nr >= 300000:
ext_list[-1].libraries.extend(["sundials_sunlinsolsuperlumt"])

ext_list[-1].include_dirs.append(self.SLUincdir)
Expand All @@ -525,7 +525,7 @@ def cython_extensionlists(self):
ext_list[-1].include_dirs = [np.get_include(), "assimulo","assimulo"+os.sep+"lib", self.incdirs]
ext_list[-1].library_dirs = [self.libdirs]
ext_list[-1].libraries = ["sundials_kinsol", "sundials_nvecserial"]
if self.SUNDIALS_version >= (7,0,0):
if self.SUNDIALS_version_nr >= 700000:
ext_list[-1].libraries.extend(["sundials_core"])

if self.sundials_with_superlu and self.with_SLU: #If SUNDIALS is compiled with support for SuperLU
Expand Down
8 changes: 4 additions & 4 deletions src/lib/sundials_callbacks.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ from numpy cimport PyArray_DATA
#=================

cdef N_Vector N_VNewEmpty_Euclidean(long int n) noexcept:
IF SUNDIALS_VERSION >= (6,0,0):
IF SUNDIALS_VERSION_NR >= 600000:
cdef SUNDIALS.SUNContext ctx = NULL
IF SUNDIALS_VERSION >= (7,0,0):
IF SUNDIALS_VERSION_NR >= 700000:
cdef SUNDIALS.SUNComm comm = 0
ELSE:
cdef void* comm = NULL
Expand All @@ -40,9 +40,9 @@ cdef inline N_Vector arr2nv(x) noexcept:
cdef long int n = len(x)
cdef np.ndarray[realtype, ndim=1,mode='c'] ndx=x
cdef void* data_ptr=PyArray_DATA(ndx)
IF SUNDIALS_VERSION >= (6,0,0):
IF SUNDIALS_VERSION_NR >= 600000:
cdef SUNDIALS.SUNContext ctx = NULL
IF SUNDIALS_VERSION >= (7,0,0):
IF SUNDIALS_VERSION_NR >= 700000:
cdef SUNDIALS.SUNComm comm = 0
ELSE:
cdef void* comm = NULL
Expand Down
14 changes: 7 additions & 7 deletions src/lib/sundials_callbacks_ida_cvode.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ cdef int cv_sens_rhs_all(int Ns, realtype t, N_Vector yv, N_Vector yvdot,
return CV_UNREC_RHSFUNC_ERR


IF SUNDIALS_VERSION >= (3,0,0):
IF SUNDIALS_VERSION_NR >= 300000:
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int cv_jac_sparse(realtype t, N_Vector yv, N_Vector fy, SUNMatrix Jac,
Expand Down Expand Up @@ -163,7 +163,7 @@ ELSE:
cdef np.ndarray[int, ndim=1, mode='c'] jindices
cdef np.ndarray[int, ndim=1, mode='c'] jindptr

IF SUNDIALS_VERSION >= (2,6,3):
IF SUNDIALS_VERSION_NR >= 200603:
cdef int* rowvals = Jacobian.rowvals[0]
cdef int* colptrs = Jacobian.colptrs[0]
ELSE:
Expand Down Expand Up @@ -212,7 +212,7 @@ ELSE:
return CVDLS_JACFUNC_UNRECVR


IF SUNDIALS_VERSION >= (3,0,0):
IF SUNDIALS_VERSION_NR >= 300000:
cdef int cv_jac(realtype t, N_Vector yv, N_Vector fy, SUNMatrix Jac,
void *problem_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) noexcept:
"""
Expand Down Expand Up @@ -363,7 +363,7 @@ cdef int cv_jacv(N_Vector vv, N_Vector Jv, realtype t, N_Vector yv, N_Vector fyv
traceback.print_exc()
return SPGMR_PSOLVE_FAIL_UNREC

IF SUNDIALS_VERSION >= (3,0,0):
IF SUNDIALS_VERSION_NR >= 300000:
cdef int cv_prec_setup(realtype t, N_Vector yy, N_Vector fyy,
bint jok, bint *jcurPtr,
realtype gamma, void *problem_data) noexcept:
Expand Down Expand Up @@ -558,7 +558,7 @@ cdef int ida_res(realtype t, N_Vector yv, N_Vector yvdot, N_Vector residual, voi
traceback.print_exc()
return IDA_RES_FAIL

IF SUNDIALS_VERSION >= (3,0,0):
IF SUNDIALS_VERSION_NR >= 300000:
cdef int ida_jac(realtype t, realtype c, N_Vector yv, N_Vector yvdot, N_Vector residual, SUNMatrix Jac,
void *problem_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3) noexcept:
"""
Expand Down Expand Up @@ -744,7 +744,7 @@ cdef int ida_jacv(realtype t, N_Vector yy, N_Vector yp, N_Vector rr, N_Vector vv

# Error handling callback functions
# =================================
IF SUNDIALS_VERSION >= (7,0,0):
IF SUNDIALS_VERSION_NR >= 700000:
cdef extern from "sundials/sundials_context.h":
ctypedef _SUNContext * SUNContext
cdef struct _SUNContext:
Expand Down Expand Up @@ -776,7 +776,7 @@ ELSE:
if error_code < 0: #Error
print('[CVode Error]', msg)

IF SUNDIALS_VERSION >= (7,0,0):
IF SUNDIALS_VERSION_NR >= 700000:
cdef void ida_err(int line, const char* func, const char* file, const char* msg, SUNErrCode error_code, void* problem_data, SUNContext sunctx) noexcept:
"""
This method overrides the default handling of error messages.
Expand Down
8 changes: 4 additions & 4 deletions src/lib/sundials_callbacks_kinsol.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import cython
import traceback
from assimulo.exception import AssimuloRecoverableError

IF SUNDIALS_VERSION >= (3,0,0):
IF SUNDIALS_VERSION_NR >= 300000:
cdef int kin_jac(N_Vector xv, N_Vector fval, SUNMatrix Jac,
void *problem_data, N_Vector tmp1, N_Vector tmp2) noexcept:
"""
Expand Down Expand Up @@ -67,7 +67,7 @@ ELSE:
except Exception:
return KINDLS_JACFUNC_RECVR #Recoverable Error (See Sundials description)

IF SUNDIALS_VERSION >= (6,0,0):
IF SUNDIALS_VERSION_NR >= 600000:
ctypedef bint kin_jacv_bool
ELSE:
ctypedef int kin_jacv_bool
Expand Down Expand Up @@ -115,7 +115,7 @@ cdef int kin_res(N_Vector xv, N_Vector fval, void *problem_data) noexcept:
traceback.print_exc()
return KIN_SYSFUNC_FAIL

IF SUNDIALS_VERSION >= (3,0,0):
IF SUNDIALS_VERSION_NR >= 300000:
cdef int kin_prec_solve(N_Vector u, N_Vector uscaleN, N_Vector fval,
N_Vector fscaleN, N_Vector v, void *problem_data) noexcept:
"""
Expand Down Expand Up @@ -221,7 +221,7 @@ ELSE:

return KIN_SUCCESS

IF SUNDIALS_VERSION >= (7,0,0):
IF SUNDIALS_VERSION_NR >= 700000:
cdef extern from "sundials/sundials_context.h":
ctypedef _SUNContext * SUNContext
cdef struct _SUNContext:
Expand Down
Loading

0 comments on commit dfe4211

Please sign in to comment.