From d63448d0f0b37669f1237fe7d596a1dd4e85083b Mon Sep 17 00:00:00 2001 From: "Sebastian M. Ernst" Date: Tue, 9 Jan 2024 14:16:09 +0100 Subject: [PATCH] more ivp draft --- src/hapsira/core/math/ivp/_brentq.c | 130 +++++++++++++++++++++++++++ src/hapsira/core/math/ivp/_zeros.c | 132 ++++++++++++++++++++++++++++ src/hapsira/core/math/ivp/_zeros.h | 40 +++++++++ 3 files changed, 302 insertions(+) create mode 100644 src/hapsira/core/math/ivp/_brentq.c create mode 100644 src/hapsira/core/math/ivp/_zeros.c create mode 100644 src/hapsira/core/math/ivp/_zeros.h diff --git a/src/hapsira/core/math/ivp/_brentq.c b/src/hapsira/core/math/ivp/_brentq.c new file mode 100644 index 000000000..1d8df1eb8 --- /dev/null +++ b/src/hapsira/core/math/ivp/_brentq.c @@ -0,0 +1,130 @@ +/* Written by Charles Harris charles.harris@sdl.usu.edu */ + +#include +#include "zeros.h" + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +/* + At the top of the loop the situation is the following: + + 1. the root is bracketed between xa and xb + 2. xa is the most recent estimate + 3. xp is the previous estimate + 4. |fp| < |fb| + + The order of xa and xp doesn't matter, but assume xp < xb. Then xa lies to + the right of xp and the assumption is that xa is increasing towards the root. + In this situation we will attempt quadratic extrapolation as long as the + condition + + * |fa| < |fp| < |fb| + + is satisfied. That is, the function value is decreasing as we go along. + Note the 4 above implies that the right inequlity already holds. + + The first check is that xa is still to the left of the root. If not, xb is + replaced by xp and the interval reverses, with xb < xa. In this situation + we will try linear interpolation. That this has happened is signaled by the + equality xb == xp; + + The second check is that |fa| < |fb|. If this is not the case, we swap + xa and xb and resort to bisection. + +*/ + +double +brentq(callback_type f, double xa, double xb, double xtol, double rtol, + int iter, void *func_data_param, scipy_zeros_info *solver_stats) +{ + double xpre = xa, xcur = xb; + double xblk = 0., fpre, fcur, fblk = 0., spre = 0., scur = 0., sbis; + /* the tolerance is 2*delta */ + double delta; + double stry, dpre, dblk; + int i; + solver_stats->error_num = INPROGRESS; + + fpre = (*f)(xpre, func_data_param); + fcur = (*f)(xcur, func_data_param); + solver_stats->funcalls = 2; + if (fpre == 0) { + solver_stats->error_num = CONVERGED; + return xpre; + } + if (fcur == 0) { + solver_stats->error_num = CONVERGED; + return xcur; + } + if (signbit(fpre)==signbit(fcur)) { + solver_stats->error_num = SIGNERR; + return 0.; + } + solver_stats->iterations = 0; + for (i = 0; i < iter; i++) { + solver_stats->iterations++; + if (fpre != 0 && fcur != 0 && + (signbit(fpre) != signbit(fcur))) { + xblk = xpre; + fblk = fpre; + spre = scur = xcur - xpre; + } + if (fabs(fblk) < fabs(fcur)) { + xpre = xcur; + xcur = xblk; + xblk = xpre; + + fpre = fcur; + fcur = fblk; + fblk = fpre; + } + + delta = (xtol + rtol*fabs(xcur))/2; + sbis = (xblk - xcur)/2; + if (fcur == 0 || fabs(sbis) < delta) { + solver_stats->error_num = CONVERGED; + return xcur; + } + + if (fabs(spre) > delta && fabs(fcur) < fabs(fpre)) { + if (xpre == xblk) { + /* interpolate */ + stry = -fcur*(xcur - xpre)/(fcur - fpre); + } + else { + /* extrapolate */ + dpre = (fpre - fcur)/(xpre - xcur); + dblk = (fblk - fcur)/(xblk - xcur); + stry = -fcur*(fblk*dblk - fpre*dpre) + /(dblk*dpre*(fblk - fpre)); + } + if (2*fabs(stry) < MIN(fabs(spre), 3*fabs(sbis) - delta)) { + /* good short step */ + spre = scur; + scur = stry; + } else { + /* bisect */ + spre = sbis; + scur = sbis; + } + } + else { + /* bisect */ + spre = sbis; + scur = sbis; + } + + xpre = xcur; fpre = fcur; + if (fabs(scur) > delta) { + xcur += scur; + } + else { + xcur += (sbis > 0 ? delta : -delta); + } + + fcur = (*f)(xcur, func_data_param); + solver_stats->funcalls++; + } + solver_stats->error_num = CONVERR; + return xcur; +} diff --git a/src/hapsira/core/math/ivp/_zeros.c b/src/hapsira/core/math/ivp/_zeros.c new file mode 100644 index 000000000..0516ffab9 --- /dev/null +++ b/src/hapsira/core/math/ivp/_zeros.c @@ -0,0 +1,132 @@ +/* + * Helper function that calls a Python function with extended arguments + */ + +static PyObject * +call_solver(solver_type solver, PyObject *self, PyObject *args) +{ + double a, b, xtol, rtol, zero; + int iter, fulloutput, disp=1, flag=0; + scipy_zeros_parameters params; + scipy_zeros_info solver_stats; + PyObject *f, *xargs; + + if (!PyArg_ParseTuple(args, "OddddiOi|i", + &f, &a, &b, &xtol, &rtol, &iter, &xargs, &fulloutput, &disp)) { + PyErr_SetString(PyExc_RuntimeError, "Unable to parse arguments"); + return NULL; + } + if (xtol < 0) { + PyErr_SetString(PyExc_ValueError, "xtol must be >= 0"); + return NULL; + } + if (iter < 0) { + PyErr_SetString(PyExc_ValueError, "maxiter should be > 0"); + return NULL; + } + + params.function = f; + params.xargs = xargs; + + if (!setjmp(params.env)) { + /* direct return */ + solver_stats.error_num = 0; + zero = solver(scipy_zeros_functions_func, a, b, xtol, rtol, + iter, (void*)¶ms, &solver_stats); + } else { + /* error return from Python function */ + return NULL; + } + + if (solver_stats.error_num != CONVERGED) { + if (solver_stats.error_num == SIGNERR) { + PyErr_SetString(PyExc_ValueError, + "f(a) and f(b) must have different signs"); + return NULL; + } + if (solver_stats.error_num == CONVERR) { + if (disp) { + char msg[100]; + PyOS_snprintf(msg, sizeof(msg), + "Failed to converge after %d iterations.", + solver_stats.iterations); + PyErr_SetString(PyExc_RuntimeError, msg); + return NULL; + } + flag = CONVERR; + } + } + else { + flag = CONVERGED; + } + if (fulloutput) { + return Py_BuildValue("diii", + zero, solver_stats.funcalls, solver_stats.iterations, flag); + } + else { + return Py_BuildValue("d", zero); + } +} + +/* + * These routines interface with the solvers through call_solver + */ + +static PyObject * +_bisect(PyObject *self, PyObject *args) +{ + return call_solver(bisect,self,args); +} + +static PyObject * +_ridder(PyObject *self, PyObject *args) +{ + return call_solver(ridder,self,args); +} + +static PyObject * +_brenth(PyObject *self, PyObject *args) +{ + return call_solver(brenth,self,args); +} + +static PyObject * +_brentq(PyObject *self, PyObject *args) +{ + return call_solver(brentq,self,args); +} + +/* + * Standard Python module interface + */ + +static PyMethodDef +Zerosmethods[] = { + {"_bisect", _bisect, METH_VARARGS, "a"}, + {"_ridder", _ridder, METH_VARARGS, "a"}, + {"_brenth", _brenth, METH_VARARGS, "a"}, + {"_brentq", _brentq, METH_VARARGS, "a"}, + {NULL, NULL} +}; + +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "_zeros", + NULL, + -1, + Zerosmethods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC +PyInit__zeros(void) +{ + PyObject *m; + + m = PyModule_Create(&moduledef); + + return m; +} diff --git a/src/hapsira/core/math/ivp/_zeros.h b/src/hapsira/core/math/ivp/_zeros.h new file mode 100644 index 000000000..14afcc13f --- /dev/null +++ b/src/hapsira/core/math/ivp/_zeros.h @@ -0,0 +1,40 @@ +/* Written by Charles Harris charles.harris@sdl.usu.edu */ + +/* Modified to not depend on Python everywhere by Travis Oliphant. + */ + +#ifndef ZEROS_H +#define ZEROS_H + +typedef struct { + int funcalls; + int iterations; + int error_num; +} scipy_zeros_info; + + +/* Must agree with _ECONVERGED, _ESIGNERR, _ECONVERR in zeros.py */ +#define CONVERGED 0 +#define SIGNERR -1 +#define CONVERR -2 +#define EVALUEERR -3 +#define INPROGRESS 1 + +typedef double (*callback_type)(double, void*); +typedef double (*solver_type)(callback_type, double, double, double, double, + int, void *, scipy_zeros_info*); + +extern double bisect(callback_type f, double xa, double xb, double xtol, + double rtol, int iter, void *func_data_param, + scipy_zeros_info *solver_stats); +extern double ridder(callback_type f, double xa, double xb, double xtol, + double rtol, int iter, void *func_data_param, + scipy_zeros_info *solver_stats); +extern double brenth(callback_type f, double xa, double xb, double xtol, + double rtol, int iter, void *func_data_param, + scipy_zeros_info *solver_stats); +extern double brentq(callback_type f, double xa, double xb, double xtol, + double rtol, int iter, void *func_data_param, + scipy_zeros_info *solver_stats); + +#endif