Skip to content

Commit c43c1c5

Browse files
authored
Remove converter, replace by macro pragma to generate overloads (#43)
* remove some whitespace * remove `toNumContextProc` converter It can cause issues in some generic / template contexts. * add `genInterp` macro pragma to generate overloads for `InterpolatorType` This takes the place of the previously automagical `converter`. * [CI] replace nim 1.4 by 1.6 * [CI] see what 2.0.8 has to say * use `toNumContextProc`, generate adaptiveGauss manually Regarding adaptiveGauss see the added comment * [CI] try 1.6 again
1 parent bd3c612 commit c43c1c5

File tree

4 files changed

+142
-31
lines changed

4 files changed

+142
-31
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
runs-on: ubuntu-latest
2626
strategy:
2727
matrix:
28-
nim: [ '1.4.x', 'stable', 'devel' ]
28+
nim: [ '1.6.x', 'stable', 'devel' ]
2929
# Steps represent a sequence of tasks that will be executed as part of the job
3030
name: Nim ${{ matrix.nim }} sample
3131
steps:

src/numericalnim/integrate.nim

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@ import arraymancer
66

77
from ./interpolate import InterpolatorType, newHermiteSpline
88

9+
# to annotate procedures with `{.genInterp.}` to generate `InterpolatorType` overloads
10+
import private/macro_utils
11+
912
## # Integration
1013
## This module implements various integration routines.
1114
## It provides:
12-
##
15+
##
1316
## ## Integrate discrete data:
1417
## - `trapz`, `simpson`: works for any spacing between points.
1518
## - `romberg`: requires equally spaced points and the number of points must be of the form 2^k + 1 ie 3, 5, 9, 17, 33, 65, 129 etc.
@@ -27,7 +30,7 @@ runnableExamples:
2730
## It also handles infinite integration limits.
2831
## - `gaussQuad`: Fixed step size Gaussian quadrature.
2932
## - `romberg`: Adaptive method based on Richardson Extrapolation.
30-
## - `adaptiveSimpson`: Adaptive step size.
33+
## - `adaptiveSimpson`: Adaptive step size.
3134
## - `simpson`: Fixed step size.
3235
## - `trapz`: Fixed step size.
3336

@@ -36,7 +39,7 @@ runnableExamples:
3639

3740
proc f(x: float, ctx: NumContext[float, float]): float =
3841
exp(x)
39-
42+
4043
let a = 0.0
4144
let b = Inf
4245
let integral = adaptiveGauss(f, a, b)
@@ -74,10 +77,9 @@ type
7477
IntervalList[T; U; V] = object
7578
list: seq[IntervalType[T, U, V]] # contains all the intervals sorted from smallest to largest error
7679

77-
7880
# N: #intervals
7981
proc trapz*[T](f: NumContextProc[T, float], xStart, xEnd: float,
80-
N = 500, ctx: NumContext[T, float] = nil): T =
82+
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
8183
## Calculate the integral of f using the trapezoidal rule.
8284
##
8385
## Input:
@@ -172,9 +174,8 @@ proc cumtrapz*[T](f: NumContextProc[T, float], X: openArray[float],
172174
t += dx
173175
result = hermiteInterpolate(X, times, y, dy)
174176
175-
176177
proc simpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
177-
N = 500, ctx: NumContext[T, float] = nil): T =
178+
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
178179
## Calculate the integral of f using Simpson's rule.
179180
##
180181
## Input:
@@ -252,7 +253,7 @@ proc simpson*[T](Y: openArray[T], X: openArray[float]): T =
252253
result += alpha * ySorted[2*i + 2] + beta * ySorted[2*i + 1] + eta * ySorted[2*i]
253254

254255
proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
255-
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
256+
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
256257
## Calculate the integral of f using an adaptive Simpson's rule.
257258
##
258259
## Input:
@@ -284,7 +285,7 @@ proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
284285
return left + right
285286
286287
proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: float,
287-
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T =
288+
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T {.genInterp.} =
288289
let zero = reused_points[0] - reused_points[0]
289290
let dx1 = (xEnd - xStart) / 2
290291
let dx2 = (xEnd - xStart) / 4
@@ -302,7 +303,7 @@ proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: floa
302303
return left + right
303304

304305
proc adaptiveSimpson2*[T](f: NumContextProc[T, float], xStart, xEnd: float,
305-
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
306+
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
306307
## Calculate the integral of f using an adaptive Simpson's rule.
307308
##
308309
## Input:
@@ -399,7 +400,7 @@ proc cumsimpson*[T](f: NumContextProc[T, float], X: openArray[float],
399400
result = hermiteInterpolate(X, t, ys, dy)
400401
401402
proc romberg*[T](f: NumContextProc[T, float], xStart, xEnd: float,
402-
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
403+
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
403404
## Calculate the integral of f using Romberg Integration.
404405
##
405406
## Input:
@@ -594,7 +595,7 @@ proc getGaussLegendreWeights(nPoints: int): tuple[nodes: seq[float], weights: se
594595
return gaussWeights[nPoints]
595596

596597
proc gaussQuad*[T](f: NumContextProc[T, float], xStart, xEnd: float,
597-
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T =
598+
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T {.genInterp.} =
598599
## Calculate the integral of f using Gaussian Quadrature.
599600
## Has 20 different sets of weights, ranging from 1 to 20 function evaluations per subinterval.
600601
##
@@ -654,7 +655,7 @@ proc calcGaussKronrod[T; U](f: NumContextProc[T, U], xStart, xEnd: U, ctx: NumCo
654655

655656

656657
proc adaptiveGaussLocal*[T](f: NumContextProc[T, float],
657-
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
658+
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
658659
## Calculate the integral of f using an locally adaptive Gauss-Kronrod Quadrature.
659660
##
660661
## Input:
@@ -872,6 +873,18 @@ proc adaptiveGauss*[T; U](f_in: NumContextProc[T, U],
872873
adaptiveGaussImpl()
873874
return totalValue
874875
876+
proc adaptiveGauss*[T](f_in: InterpolatorType[T]; xStart_in, xEnd_in: T;
877+
tol = 1e-8; initialPoints: openArray[T] = @[];
878+
maxintervals: int = 10000; ctx: NumContext[T, T] = nil): T =
879+
## NOTE: On Nim 2.0.8 we cannot use `{.genInterp.}` on the above proc, because of
880+
## of the double generic it has `[T; U]`. It fails. So this is just a manual version
881+
## of the generated code for the time being.
882+
mixin eval
883+
mixin InterpolatorType
884+
mixin toNumContextProc
885+
let ncp = toNumContextProc(f_in)
886+
adaptiveGauss(ncp, xStart_in, xEnd_in, tol, initialPoints, maxintervals, ctx)
887+
875888
proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
876889
xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): InterpolatorType[T] =
877890
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature. Inf and -Inf can be used as integration limits.
@@ -909,7 +922,10 @@ proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
909922
result = newHermiteSpline[T](xs, ys)
910923
911924
proc cumGauss*[T](f_in: NumContextProc[T, float],
912-
X: openArray[float], tol = 1e-8, initialPoints: openArray[float] = @[], maxintervals: int = 10000, ctx: NumContext[T, float] = nil): seq[T] =
925+
X: openArray[float], tol = 1e-8,
926+
initialPoints: openArray[float] = @[],
927+
maxintervals: int = 10000,
928+
ctx: NumContext[T, float] = nil): seq[T] {.genInterp.} =
913929
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature.
914930
## Returns a sequence of values which is the cumulative integral of f at the points defined in X.
915931
## Important: because of the much higher order of the Gauss-Kronrod quadrature (order 21) compared to the interpolating Hermite spline (order 3) you have to give it a large amount of initialPoints.

src/numericalnim/interpolate.nim

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ export rbf
1212
## This module implements various interpolation routines.
1313
## See also:
1414
## - `rbf module<rbf.html>`_ for RBF interpolation of scattered data in arbitrary dimensions.
15-
##
15+
##
1616
## ## 1D interpolation
1717
## - Hermite spline (recommended): cubic spline that works with many types of values. Accepts derivatives if available.
1818
## - Cubic spline: cubic spline that only works with `float`s.
1919
## - Linear spline: Linear spline that works with many types of values.
20-
##
20+
##
2121
## ### Extrapolation
2222
## Extrapolation is supported for all 1D interpolators by passing the type of extrapolation as an argument of `eval`.
2323
## The default is to use the interpolator's native method to extrapolate. This means that Linear does linear extrapolation,
@@ -26,7 +26,7 @@ export rbf
2626

2727
runnableExamples:
2828
import numericalnim, std/[math, sequtils]
29-
29+
3030
let x = linspace(0.0, 1.0, 10)
3131
let y = x.mapIt(sin(it))
3232

@@ -173,7 +173,7 @@ proc derivEval_cubicspline*[T](spline: InterpolatorType[T], x: float): T =
173173

174174
proc newCubicSpline*[T: SomeFloat](X: openArray[float], Y: openArray[
175175
T]): InterpolatorType[T] =
176-
## Returns a cubic spline.
176+
## Returns a cubic spline.
177177
let (xSorted, ySorted) = sortAndTrimDataset(@X, @Y)
178178
let coeffs = constructCubicSpline(xSorted, ySorted)
179179
result = InterpolatorType[T](X: xSorted, Y: ySorted, coeffs_T: coeffs, high: xSorted.high,
@@ -241,7 +241,7 @@ proc newHermiteSpline*[T](X: openArray[float], Y, dY: openArray[
241241
242242
proc newHermiteSpline*[T](X: openArray[float], Y: openArray[
243243
T]): InterpolatorType[T] =
244-
## Constructs a cubic Hermite spline by approximating the derivatives.
244+
## Constructs a cubic Hermite spline by approximating the derivatives.
245245
# if only (x, y) is given, use three-point difference to calculate dY.
246246
let (xSorted, ySorted) = sortAndTrimDataset(@X, @Y)
247247
var dySorted = newSeq[T](ySorted.len)
@@ -304,16 +304,16 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
304304
## - `Edge`: Use the value of the left/right edge.
305305
## - `Linear`: Uses linear extrapolation using the two points closest to the edge.
306306
## - `Native` (default): Uses the native method of the interpolator to extrapolate. For Linear1D it will be a linear extrapolation, and for Cubic and Hermite splines it will be cubic extrapolation.
307-
## - `Error`: Raises an `ValueError` if `x` is outside the range.
307+
## - `Error`: Raises an `ValueError` if `x` is outside the range.
308308
## - `extrapValue`: The extrapolation value to use when `extrap = Constant`.
309-
##
309+
##
310310
## > Beware: `Native` extrapolation for the cubic splines can very quickly diverge if the extrapolation is too far away from the interpolation points.
311311
when U is Missing:
312312
assert extrap != Constant, "When using `extrap = Constant`, a value `extrapValue` must be supplied!"
313313
else:
314314
when not T is U:
315315
{.error: &"Type of `extrap` ({U}) is not the same as the type of the interpolator ({T})!".}
316-
316+
317317
let xLeft = x < interpolator.X[0]
318318
let xRight = x > interpolator.X[^1]
319319
if xLeft or xRight:
@@ -330,7 +330,7 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
330330
if xLeft: interpolator.Y[0]
331331
else: interpolator.Y[^1]
332332
of Linear:
333-
let (xs, ys) =
333+
let (xs, ys) =
334334
if xLeft:
335335
((interpolator.X[0], interpolator.X[1]), (interpolator.Y[0], interpolator.Y[1]))
336336
else:
@@ -341,7 +341,7 @@ proc eval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extrapolat
341341
raise newException(ValueError, &"x = {x} isn't in the interval [{interpolator.X[0]}, {interpolator.X[^1]}]")
342342

343343
result = interpolator.eval_handler(interpolator, x)
344-
344+
345345

346346
proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: ExtrapolateKind = Native, extrapValue: U = missing()): T =
347347
## Evaluates the derivative of an interpolator.
@@ -351,9 +351,9 @@ proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extra
351351
## - `Edge`: Use the value of the left/right edge.
352352
## - `Linear`: Uses linear extrapolation using the two points closest to the edge.
353353
## - `Native` (default): Uses the native method of the interpolator to extrapolate. For Linear1D it will be a linear extrapolation, and for Cubic and Hermite splines it will be cubic extrapolation.
354-
## - `Error`: Raises an `ValueError` if `x` is outside the range.
354+
## - `Error`: Raises an `ValueError` if `x` is outside the range.
355355
## - `extrapValue`: The extrapolation value to use when `extrap = Constant`.
356-
##
356+
##
357357
## > Beware: `Native` extrapolation for the cubic splines can very quickly diverge if the extrapolation is too far away from the interpolation points.
358358
when U is Missing:
359359
assert extrap != Constant, "When using `extrap = Constant`, a value `extrapValue` must be supplied!"
@@ -390,7 +390,7 @@ proc derivEval*[T; U](interpolator: InterpolatorType[T], x: float, extrap: Extra
390390
result = interpolator.deriveval_handler(interpolator, x)
391391

392392
proc eval*[T; U](spline: InterpolatorType[T], x: openArray[float], extrap: ExtrapolateKind = Native, extrapValue: U = missing()): seq[T] =
393-
## Evaluates an interpolator at all points in `x`.
393+
## Evaluates an interpolator at all points in `x`.
394394
result = newSeq[T](x.len)
395395
for i, xi in x:
396396
result[i] = eval(spline, xi, extrap, extrapValue)
@@ -399,7 +399,7 @@ proc toProc*[T](spline: InterpolatorType[T]): InterpolatorProc[T] =
399399
## Returns a proc to evaluate the interpolator.
400400
result = proc(x: float): T = eval(spline, x)
401401
402-
converter toNumContextProc*[T](spline: InterpolatorType[T]): NumContextProc[T, float] =
402+
proc toNumContextProc*[T](spline: InterpolatorType[T]): NumContextProc[T, float] =
403403
## Convert interpolator to `NumContextProc`.
404404
result = proc(x: float, ctx: NumContext[T, float]): T = eval(spline, x)
405405
@@ -655,11 +655,11 @@ proc eval_barycentric2d*[T, U](self: InterpolatorUnstructured2DType[T, U]; x, y:
655655

656656
proc newBarycentric2D*[T: SomeFloat, U](points: Tensor[T], values: Tensor[U]): InterpolatorUnstructured2DType[T, U] =
657657
## Barycentric interpolation of scattered points in 2D.
658-
##
658+
##
659659
## Inputs:
660660
## - points: Tensor of shape (nPoints, 2) with the coordinates of all points.
661661
## - values: Tensor of shape (nPoints) with the function values.
662-
##
662+
##
663663
## Returns:
664664
## - Interpolator object that can be evaluated using `interp.eval(x, y`.
665665
assert points.rank == 2 and points.shape[1] == 2
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import std / macros
2+
proc checkArgNumContext(fn: NimNode) =
3+
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
4+
let params = fn.params
5+
# FormalParams <- `.params`
6+
# Ident "T"
7+
# IdentDefs <- `params[1]`
8+
# Sym "f"
9+
# BracketExpr <- `params[1][1]`
10+
# Sym "NumContextProc" <- `params[1][1][0]`
11+
# Ident "T"
12+
# Sym "float"
13+
# Empty
14+
expectKind params, nnkFormalParams
15+
expectKind params[1], nnkIdentDefs
16+
expectKind params[1][1], nnkBracketExpr
17+
expectKind params[1][1][0], {nnkSym, nnkIdent}
18+
if params[1][1][0].strVal != "NumContextProc":
19+
error("The function annotated with `{.genInterp.}` does not take a `NumContextProc` as the firs argument.")
20+
21+
proc replaceNumCtxArg(fn: NimNode): NimNode =
22+
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
23+
## MUST run `checkArgNumContext` on `fn` first.
24+
##
25+
## It returns the identifier of the first argument.
26+
var params = fn.params # see `checkArgNNumContext`
27+
expectKind params[1][0], {nnkSym, nnkIdent}
28+
result = ident(params[1][0].strVal)
29+
params[1] = nnkIdentDefs.newTree(
30+
result,
31+
nnkBracketExpr.newTree(
32+
ident"InterpolatorType",
33+
ident"T"
34+
),
35+
newEmptyNode()
36+
)
37+
fn.params = params
38+
39+
proc untype(n: NimNode): NimNode =
40+
case n.kind
41+
of nnkSym: result = ident(n.strVal)
42+
of nnkIdent: result = n
43+
else:
44+
error("Cannot untype the argument: " & $n.treerepr)
45+
46+
proc genOriginalCall(fn: NimNode, ncp: NimNode): NimNode =
47+
## Generates a call to the original procedure `fn` with `ncp`
48+
## as the first argument
49+
let fnName = fn.name
50+
let params = fn.params
51+
# extract all arguments we need to pass from `params`
52+
var p = newSeq[NimNode]()
53+
p.add ncp
54+
for i in 2 ..< params.len: # first param is return type, second is parameter we replace
55+
expectKind params[i], nnkIdentDefs
56+
if params[i].len in 0 .. 2:
57+
error("Invalid parameter: " & $params[i].treerepr)
58+
else: # one or more arg of this type
59+
# IdentDefs <- Example with 2 arguments of the same type
60+
# Ident "xStart" <- index `0`
61+
# Ident "xEnd" <- index `len - 3 = 4 - 3 = 1`
62+
# Ident "float"
63+
# Empty
64+
for j in 0 .. params[i].len - 3:
65+
p.add untype(params[i][j])
66+
# generate the call
67+
result = nnkCall.newTree(fnName)
68+
for el in p:
69+
result.add el
70+
71+
macro genInterp*(fn: untyped): untyped =
72+
## Takes a `proc` with a `NumContextProc` parameter as the first argument
73+
## and returns two procedures:
74+
## 1. The original proc
75+
## 2. An overload, which converts an `InterpolatorType[T]` argument to a
76+
## `NumContextProc[T, float]` using the conversion proc.
77+
doAssert fn.kind in {nnkProcDef, nnkFuncDef}
78+
result = newStmtList(fn)
79+
# 1. check arg
80+
checkArgNumContext(fn)
81+
# 2. generate overload
82+
var new = fn.copyNimTree()
83+
# 2a. replace first argument by `InterpolatorType[T]`
84+
let arg = new.replaceNumCtxArg()
85+
# 2b. add body with NumContextProc
86+
let ncpIdent = ident"ncp"
87+
new.body = quote do:
88+
mixin eval # defined in `interpolate`, but macro used in `integrate`
89+
mixin InterpolatorType
90+
mixin toNumContextProc
91+
let `ncpIdent` = toNumContextProc(`arg`)
92+
# 2c. add call to original proc
93+
new.body.add genOriginalCall(fn, ncpIdent)
94+
# 3. finalize
95+
result.add new

0 commit comments

Comments
 (0)