Skip to content

Commit

Permalink
Remove converter, replace by macro pragma to generate overloads (#43)
Browse files Browse the repository at this point in the history
* remove some whitespace

* remove `toNumContextProc` converter

It can cause issues in some generic / template contexts.

* add `genInterp` macro pragma to generate overloads for `InterpolatorType`

This takes the place of the previously automagical `converter`.

* [CI] replace nim 1.4 by 1.6

* [CI] see what 2.0.8 has to say

* use `toNumContextProc`, generate adaptiveGauss manually

Regarding adaptiveGauss see the added comment

* [CI] try 1.6 again
  • Loading branch information
Vindaar authored Sep 13, 2024
1 parent bd3c612 commit c43c1c5
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
nim: [ '1.4.x', 'stable', 'devel' ]
nim: [ '1.6.x', 'stable', 'devel' ]
# Steps represent a sequence of tasks that will be executed as part of the job
name: Nim ${{ matrix.nim }} sample
steps:
Expand Down
44 changes: 30 additions & 14 deletions src/numericalnim/integrate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import arraymancer

from ./interpolate import InterpolatorType, newHermiteSpline

# to annotate procedures with `{.genInterp.}` to generate `InterpolatorType` overloads
import private/macro_utils

## # Integration
## This module implements various integration routines.
## It provides:
##
##
## ## Integrate discrete data:
## - `trapz`, `simpson`: works for any spacing between points.
## - `romberg`: requires equally spaced points and the number of points must be of the form 2^k + 1 ie 3, 5, 9, 17, 33, 65, 129 etc.
Expand All @@ -27,7 +30,7 @@ runnableExamples:
## It also handles infinite integration limits.
## - `gaussQuad`: Fixed step size Gaussian quadrature.
## - `romberg`: Adaptive method based on Richardson Extrapolation.
## - `adaptiveSimpson`: Adaptive step size.
## - `adaptiveSimpson`: Adaptive step size.
## - `simpson`: Fixed step size.
## - `trapz`: Fixed step size.

Expand All @@ -36,7 +39,7 @@ runnableExamples:

proc f(x: float, ctx: NumContext[float, float]): float =
exp(x)

let a = 0.0
let b = Inf
let integral = adaptiveGauss(f, a, b)
Expand Down Expand Up @@ -74,10 +77,9 @@ type
IntervalList[T; U; V] = object
list: seq[IntervalType[T, U, V]] # contains all the intervals sorted from smallest to largest error


# N: #intervals
proc trapz*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 500, ctx: NumContext[T, float] = nil): T =
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using the trapezoidal rule.
##
## Input:
Expand Down Expand Up @@ -172,9 +174,8 @@ proc cumtrapz*[T](f: NumContextProc[T, float], X: openArray[float],
t += dx
result = hermiteInterpolate(X, times, y, dy)


proc simpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 500, ctx: NumContext[T, float] = nil): T =
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -252,7 +253,7 @@ proc simpson*[T](Y: openArray[T], X: openArray[float]): T =
result += alpha * ySorted[2*i + 2] + beta * ySorted[2*i + 1] + eta * ySorted[2*i]

proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an adaptive Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -284,7 +285,7 @@ proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
return left + right

proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T =
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T {.genInterp.} =
let zero = reused_points[0] - reused_points[0]
let dx1 = (xEnd - xStart) / 2
let dx2 = (xEnd - xStart) / 4
Expand All @@ -302,7 +303,7 @@ proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: floa
return left + right

proc adaptiveSimpson2*[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an adaptive Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -399,7 +400,7 @@ proc cumsimpson*[T](f: NumContextProc[T, float], X: openArray[float],
result = hermiteInterpolate(X, t, ys, dy)

proc romberg*[T](f: NumContextProc[T, float], xStart, xEnd: float,
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Romberg Integration.
##
## Input:
Expand Down Expand Up @@ -594,7 +595,7 @@ proc getGaussLegendreWeights(nPoints: int): tuple[nodes: seq[float], weights: se
return gaussWeights[nPoints]

proc gaussQuad*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T =
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Gaussian Quadrature.
## Has 20 different sets of weights, ranging from 1 to 20 function evaluations per subinterval.
##
Expand Down Expand Up @@ -654,7 +655,7 @@ proc calcGaussKronrod[T; U](f: NumContextProc[T, U], xStart, xEnd: U, ctx: NumCo


proc adaptiveGaussLocal*[T](f: NumContextProc[T, float],
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an locally adaptive Gauss-Kronrod Quadrature.
##
## Input:
Expand Down Expand Up @@ -872,6 +873,18 @@ proc adaptiveGauss*[T; U](f_in: NumContextProc[T, U],
adaptiveGaussImpl()
return totalValue

proc adaptiveGauss*[T](f_in: InterpolatorType[T]; xStart_in, xEnd_in: T;
tol = 1e-8; initialPoints: openArray[T] = @[];
maxintervals: int = 10000; ctx: NumContext[T, T] = nil): T =
## NOTE: On Nim 2.0.8 we cannot use `{.genInterp.}` on the above proc, because of
## of the double generic it has `[T; U]`. It fails. So this is just a manual version
## of the generated code for the time being.
mixin eval
mixin InterpolatorType
mixin toNumContextProc
let ncp = toNumContextProc(f_in)
adaptiveGauss(ncp, xStart_in, xEnd_in, tol, initialPoints, maxintervals, ctx)

proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): InterpolatorType[T] =
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature. Inf and -Inf can be used as integration limits.
Expand Down Expand Up @@ -909,7 +922,10 @@ proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
result = newHermiteSpline[T](xs, ys)

proc cumGauss*[T](f_in: NumContextProc[T, float],
X: openArray[float], tol = 1e-8, initialPoints: openArray[float] = @[], maxintervals: int = 10000, ctx: NumContext[T, float] = nil): seq[T] =
X: openArray[float], tol = 1e-8,
initialPoints: openArray[float] = @[],
maxintervals: int = 10000,
ctx: NumContext[T, float] = nil): seq[T] {.genInterp.} =
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature.
## Returns a sequence of values which is the cumulative integral of f at the points defined in X.
## Important: because of the much higher order of the Gauss-Kronrod quadrature (order 21) compared to the interpolating Hermite spline (order 3) you have to give it a large amount of initialPoints.
Expand Down
32 changes: 16 additions & 16 deletions src/numericalnim/interpolate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ export rbf
## This module implements various interpolation routines.
## See also:
## - `rbf module<rbf.html>`_ for RBF interpolation of scattered data in arbitrary dimensions.
##
##
## ## 1D interpolation
## - Hermite spline (recommended): cubic spline that works with many types of values. Accepts derivatives if available.
## - Cubic spline: cubic spline that only works with `float`s.
## - Linear spline: Linear spline that works with many types of values.
##
##
## ### Extrapolation
## Extrapolation is supported for all 1D interpolators by passing the type of extrapolation as an argument of `eval`.
## The default is to use the interpolator's native method to extrapolate. This means that Linear does linear extrapolation,
Expand All @@ -26,7 +26,7 @@ export rbf

runnableExamples:
import numericalnim, std/[math, sequtils]

let x = linspace(0.0, 1.0, 10)
let y = x.mapIt(sin(it))

Expand Down Expand Up @@ -173,7 +173,7 @@ proc derivEval_cubicspline*[T](spline: InterpolatorType[T], x: float): T =

proc newCubicSpline*[T: SomeFloat](X: openArray[float], Y: openArray[
T]): InterpolatorType[T] =
## Returns a cubic spline.
## Returns a cubic spline.
let (xSorted, ySorted) = sortAndTrimDataset(@X, @Y)
let coeffs = constructCubicSpline(xSorted, ySorted)
result = InterpolatorType[T](X: xSorted, Y: ySorted, coeffs_T: coeffs, high: xSorted.high,
Expand Down Expand Up @@ -241,7 +241,7 @@ proc newHermiteSpline*[T](X: openArray[float], Y, dY: openArray[

proc newHermiteSpline*[T](X: openArray[float], Y: openArray[
T]): InterpolatorType[T] =
## Constructs a cubic Hermite spline by approximating the derivatives.
## Constructs a cubic Hermite spline by approximating the derivatives.
# if only (x, y) is given, use three-point difference to calculate dY.
let (xSorted, ySorted) = sortAndTrimDataset(@X, @Y)
var dySorted = newSeq[T](ySorted.len)
Expand Down Expand Up @@ -304,16 +304,16 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
## - `Edge`: Use the value of the left/right edge.
## - `Linear`: Uses linear extrapolation using the two points closest to the edge.
## - `Native` (default): Uses the native method of the interpolator to extrapolate. For Linear1D it will be a linear extrapolation, and for Cubic and Hermite splines it will be cubic extrapolation.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `extrapValue`: The extrapolation value to use when `extrap = Constant`.
##
##
## > Beware: `Native` extrapolation for the cubic splines can very quickly diverge if the extrapolation is too far away from the interpolation points.
when U is Missing:
assert extrap != Constant, "When using `extrap = Constant`, a value `extrapValue` must be supplied!"
else:
when not T is U:
{.error: &"Type of `extrap` ({U}) is not the same as the type of the interpolator ({T})!".}

let xLeft = x < interpolator.X[0]
let xRight = x > interpolator.X[^1]
if xLeft or xRight:
Expand All @@ -330,7 +330,7 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
if xLeft: interpolator.Y[0]
else: interpolator.Y[^1]
of Linear:
let (xs, ys) =
let (xs, ys) =
if xLeft:
((interpolator.X[0], interpolator.X[1]), (interpolator.Y[0], interpolator.Y[1]))
else:
Expand All @@ -341,7 +341,7 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
raise newException(ValueError, &"x = {x} isn't in the interval [{interpolator.X[0]}, {interpolator.X[^1]}]")

result = interpolator.eval_handler(interpolator, x)


proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: ExtrapolateKind = Native, extrapValue: U = missing()): T =
## Evaluates the derivative of an interpolator.
Expand All @@ -351,9 +351,9 @@ proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extra
## - `Edge`: Use the value of the left/right edge.
## - `Linear`: Uses linear extrapolation using the two points closest to the edge.
## - `Native` (default): Uses the native method of the interpolator to extrapolate. For Linear1D it will be a linear extrapolation, and for Cubic and Hermite splines it will be cubic extrapolation.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `Error`: Raises an `ValueError` if `x` is outside the range.
## - `extrapValue`: The extrapolation value to use when `extrap = Constant`.
##
##
## > Beware: `Native` extrapolation for the cubic splines can very quickly diverge if the extrapolation is too far away from the interpolation points.
when U is Missing:
assert extrap != Constant, "When using `extrap = Constant`, a value `extrapValue` must be supplied!"
Expand Down Expand Up @@ -390,7 +390,7 @@ proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extra
result = interpolator.deriveval_handler(interpolator, x)

proc eval*[T; U](spline: InterpolatorType[T], x: openArray[float], extrap: ExtrapolateKind = Native, extrapValue: U = missing()): seq[T] =
## Evaluates an interpolator at all points in `x`.
## Evaluates an interpolator at all points in `x`.
result = newSeq[T](x.len)
for i, xi in x:
result[i] = eval(spline, xi, extrap, extrapValue)
Expand All @@ -399,7 +399,7 @@ proc toProc*[T](spline: InterpolatorType[T]): InterpolatorProc[T] =
## Returns a proc to evaluate the interpolator.
result = proc(x: float): T = eval(spline, x)

converter toNumContextProc*[T](spline: InterpolatorType[T]): NumContextProc[T, float] =
proc toNumContextProc*[T](spline: InterpolatorType[T]): NumContextProc[T, float] =
## Convert interpolator to `NumContextProc`.
result = proc(x: float, ctx: NumContext[T, float]): T = eval(spline, x)

Expand Down Expand Up @@ -655,11 +655,11 @@ proc eval_barycentric2d*[T, U](self: InterpolatorUnstructured2DType[T, U]; x, y:

proc newBarycentric2D*[T: SomeFloat, U](points: Tensor[T], values: Tensor[U]): InterpolatorUnstructured2DType[T, U] =
## Barycentric interpolation of scattered points in 2D.
##
##
## Inputs:
## - points: Tensor of shape (nPoints, 2) with the coordinates of all points.
## - values: Tensor of shape (nPoints) with the function values.
##
##
## Returns:
## - Interpolator object that can be evaluated using `interp.eval(x, y`.
assert points.rank == 2 and points.shape[1] == 2
Expand Down
95 changes: 95 additions & 0 deletions src/numericalnim/private/macro_utils.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import std / macros
proc checkArgNumContext(fn: NimNode) =
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
let params = fn.params
# FormalParams <- `.params`
# Ident "T"
# IdentDefs <- `params[1]`
# Sym "f"
# BracketExpr <- `params[1][1]`
# Sym "NumContextProc" <- `params[1][1][0]`
# Ident "T"
# Sym "float"
# Empty
expectKind params, nnkFormalParams
expectKind params[1], nnkIdentDefs
expectKind params[1][1], nnkBracketExpr
expectKind params[1][1][0], {nnkSym, nnkIdent}
if params[1][1][0].strVal != "NumContextProc":
error("The function annotated with `{.genInterp.}` does not take a `NumContextProc` as the firs argument.")

proc replaceNumCtxArg(fn: NimNode): NimNode =
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
## MUST run `checkArgNumContext` on `fn` first.
##
## It returns the identifier of the first argument.
var params = fn.params # see `checkArgNNumContext`
expectKind params[1][0], {nnkSym, nnkIdent}
result = ident(params[1][0].strVal)
params[1] = nnkIdentDefs.newTree(
result,
nnkBracketExpr.newTree(
ident"InterpolatorType",
ident"T"
),
newEmptyNode()
)
fn.params = params

proc untype(n: NimNode): NimNode =
case n.kind
of nnkSym: result = ident(n.strVal)
of nnkIdent: result = n
else:
error("Cannot untype the argument: " & $n.treerepr)

proc genOriginalCall(fn: NimNode, ncp: NimNode): NimNode =
## Generates a call to the original procedure `fn` with `ncp`
## as the first argument
let fnName = fn.name
let params = fn.params
# extract all arguments we need to pass from `params`
var p = newSeq[NimNode]()
p.add ncp
for i in 2 ..< params.len: # first param is return type, second is parameter we replace
expectKind params[i], nnkIdentDefs
if params[i].len in 0 .. 2:
error("Invalid parameter: " & $params[i].treerepr)
else: # one or more arg of this type
# IdentDefs <- Example with 2 arguments of the same type
# Ident "xStart" <- index `0`
# Ident "xEnd" <- index `len - 3 = 4 - 3 = 1`
# Ident "float"
# Empty
for j in 0 .. params[i].len - 3:
p.add untype(params[i][j])
# generate the call
result = nnkCall.newTree(fnName)
for el in p:
result.add el

macro genInterp*(fn: untyped): untyped =
## Takes a `proc` with a `NumContextProc` parameter as the first argument
## and returns two procedures:
## 1. The original proc
## 2. An overload, which converts an `InterpolatorType[T]` argument to a
## `NumContextProc[T, float]` using the conversion proc.
doAssert fn.kind in {nnkProcDef, nnkFuncDef}
result = newStmtList(fn)
# 1. check arg
checkArgNumContext(fn)
# 2. generate overload
var new = fn.copyNimTree()
# 2a. replace first argument by `InterpolatorType[T]`
let arg = new.replaceNumCtxArg()
# 2b. add body with NumContextProc
let ncpIdent = ident"ncp"
new.body = quote do:
mixin eval # defined in `interpolate`, but macro used in `integrate`
mixin InterpolatorType
mixin toNumContextProc
let `ncpIdent` = toNumContextProc(`arg`)
# 2c. add call to original proc
new.body.add genOriginalCall(fn, ncpIdent)
# 3. finalize
result.add new

0 comments on commit c43c1c5

Please sign in to comment.