Skip to content

Commit

Permalink
more ivp draft
Browse files Browse the repository at this point in the history
  • Loading branch information
s-m-e committed Jan 9, 2024
1 parent 942df19 commit d63448d
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 0 deletions.
130 changes: 130 additions & 0 deletions src/hapsira/core/math/ivp/_brentq.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/* Written by Charles Harris [email protected] */

#include <math.h>
#include "zeros.h"

#define MIN(a, b) ((a) < (b) ? (a) : (b))

/*
At the top of the loop the situation is the following:
1. the root is bracketed between xa and xb
2. xa is the most recent estimate
3. xp is the previous estimate
4. |fp| < |fb|
The order of xa and xp doesn't matter, but assume xp < xb. Then xa lies to
the right of xp and the assumption is that xa is increasing towards the root.
In this situation we will attempt quadratic extrapolation as long as the
condition
* |fa| < |fp| < |fb|
is satisfied. That is, the function value is decreasing as we go along.
Note the 4 above implies that the right inequlity already holds.
The first check is that xa is still to the left of the root. If not, xb is
replaced by xp and the interval reverses, with xb < xa. In this situation
we will try linear interpolation. That this has happened is signaled by the
equality xb == xp;
The second check is that |fa| < |fb|. If this is not the case, we swap
xa and xb and resort to bisection.
*/

double
brentq(callback_type f, double xa, double xb, double xtol, double rtol,
int iter, void *func_data_param, scipy_zeros_info *solver_stats)
{
double xpre = xa, xcur = xb;
double xblk = 0., fpre, fcur, fblk = 0., spre = 0., scur = 0., sbis;
/* the tolerance is 2*delta */
double delta;
double stry, dpre, dblk;
int i;
solver_stats->error_num = INPROGRESS;

fpre = (*f)(xpre, func_data_param);
fcur = (*f)(xcur, func_data_param);
solver_stats->funcalls = 2;
if (fpre == 0) {
solver_stats->error_num = CONVERGED;
return xpre;
}
if (fcur == 0) {
solver_stats->error_num = CONVERGED;
return xcur;
}
if (signbit(fpre)==signbit(fcur)) {
solver_stats->error_num = SIGNERR;
return 0.;
}
solver_stats->iterations = 0;
for (i = 0; i < iter; i++) {
solver_stats->iterations++;
if (fpre != 0 && fcur != 0 &&
(signbit(fpre) != signbit(fcur))) {
xblk = xpre;
fblk = fpre;
spre = scur = xcur - xpre;
}
if (fabs(fblk) < fabs(fcur)) {
xpre = xcur;
xcur = xblk;
xblk = xpre;

fpre = fcur;
fcur = fblk;
fblk = fpre;
}

delta = (xtol + rtol*fabs(xcur))/2;
sbis = (xblk - xcur)/2;
if (fcur == 0 || fabs(sbis) < delta) {
solver_stats->error_num = CONVERGED;
return xcur;
}

if (fabs(spre) > delta && fabs(fcur) < fabs(fpre)) {
if (xpre == xblk) {
/* interpolate */
stry = -fcur*(xcur - xpre)/(fcur - fpre);
}
else {
/* extrapolate */
dpre = (fpre - fcur)/(xpre - xcur);
dblk = (fblk - fcur)/(xblk - xcur);
stry = -fcur*(fblk*dblk - fpre*dpre)
/(dblk*dpre*(fblk - fpre));
}
if (2*fabs(stry) < MIN(fabs(spre), 3*fabs(sbis) - delta)) {
/* good short step */
spre = scur;
scur = stry;
} else {
/* bisect */
spre = sbis;
scur = sbis;
}
}
else {
/* bisect */
spre = sbis;
scur = sbis;
}

xpre = xcur; fpre = fcur;
if (fabs(scur) > delta) {
xcur += scur;
}
else {
xcur += (sbis > 0 ? delta : -delta);
}

fcur = (*f)(xcur, func_data_param);
solver_stats->funcalls++;
}
solver_stats->error_num = CONVERR;
return xcur;
}
132 changes: 132 additions & 0 deletions src/hapsira/core/math/ivp/_zeros.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Helper function that calls a Python function with extended arguments
*/

static PyObject *
call_solver(solver_type solver, PyObject *self, PyObject *args)
{
double a, b, xtol, rtol, zero;
int iter, fulloutput, disp=1, flag=0;
scipy_zeros_parameters params;
scipy_zeros_info solver_stats;
PyObject *f, *xargs;

if (!PyArg_ParseTuple(args, "OddddiOi|i",
&f, &a, &b, &xtol, &rtol, &iter, &xargs, &fulloutput, &disp)) {
PyErr_SetString(PyExc_RuntimeError, "Unable to parse arguments");
return NULL;
}
if (xtol < 0) {
PyErr_SetString(PyExc_ValueError, "xtol must be >= 0");
return NULL;
}
if (iter < 0) {
PyErr_SetString(PyExc_ValueError, "maxiter should be > 0");
return NULL;
}

params.function = f;
params.xargs = xargs;

if (!setjmp(params.env)) {
/* direct return */
solver_stats.error_num = 0;
zero = solver(scipy_zeros_functions_func, a, b, xtol, rtol,
iter, (void*)&params, &solver_stats);
} else {
/* error return from Python function */
return NULL;
}

if (solver_stats.error_num != CONVERGED) {
if (solver_stats.error_num == SIGNERR) {
PyErr_SetString(PyExc_ValueError,
"f(a) and f(b) must have different signs");
return NULL;
}
if (solver_stats.error_num == CONVERR) {
if (disp) {
char msg[100];
PyOS_snprintf(msg, sizeof(msg),
"Failed to converge after %d iterations.",
solver_stats.iterations);
PyErr_SetString(PyExc_RuntimeError, msg);
return NULL;
}
flag = CONVERR;
}
}
else {
flag = CONVERGED;
}
if (fulloutput) {
return Py_BuildValue("diii",
zero, solver_stats.funcalls, solver_stats.iterations, flag);
}
else {
return Py_BuildValue("d", zero);
}
}

/*
* These routines interface with the solvers through call_solver
*/

static PyObject *
_bisect(PyObject *self, PyObject *args)
{
return call_solver(bisect,self,args);
}

static PyObject *
_ridder(PyObject *self, PyObject *args)
{
return call_solver(ridder,self,args);
}

static PyObject *
_brenth(PyObject *self, PyObject *args)
{
return call_solver(brenth,self,args);
}

static PyObject *
_brentq(PyObject *self, PyObject *args)
{
return call_solver(brentq,self,args);
}

/*
* Standard Python module interface
*/

static PyMethodDef
Zerosmethods[] = {
{"_bisect", _bisect, METH_VARARGS, "a"},
{"_ridder", _ridder, METH_VARARGS, "a"},
{"_brenth", _brenth, METH_VARARGS, "a"},
{"_brentq", _brentq, METH_VARARGS, "a"},
{NULL, NULL}
};

static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_zeros",
NULL,
-1,
Zerosmethods,
NULL,
NULL,
NULL,
NULL
};

PyMODINIT_FUNC
PyInit__zeros(void)
{
PyObject *m;

m = PyModule_Create(&moduledef);

return m;
}
40 changes: 40 additions & 0 deletions src/hapsira/core/math/ivp/_zeros.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Written by Charles Harris [email protected] */

/* Modified to not depend on Python everywhere by Travis Oliphant.
*/

#ifndef ZEROS_H
#define ZEROS_H

typedef struct {
int funcalls;
int iterations;
int error_num;
} scipy_zeros_info;


/* Must agree with _ECONVERGED, _ESIGNERR, _ECONVERR in zeros.py */
#define CONVERGED 0
#define SIGNERR -1
#define CONVERR -2
#define EVALUEERR -3
#define INPROGRESS 1

typedef double (*callback_type)(double, void*);
typedef double (*solver_type)(callback_type, double, double, double, double,
int, void *, scipy_zeros_info*);

extern double bisect(callback_type f, double xa, double xb, double xtol,
double rtol, int iter, void *func_data_param,
scipy_zeros_info *solver_stats);
extern double ridder(callback_type f, double xa, double xb, double xtol,
double rtol, int iter, void *func_data_param,
scipy_zeros_info *solver_stats);
extern double brenth(callback_type f, double xa, double xb, double xtol,
double rtol, int iter, void *func_data_param,
scipy_zeros_info *solver_stats);
extern double brentq(callback_type f, double xa, double xb, double xtol,
double rtol, int iter, void *func_data_param,
scipy_zeros_info *solver_stats);

#endif

0 comments on commit d63448d

Please sign in to comment.