Skip to content

Commit

Permalink
TL: added fEval feature to LagrangeApproximation
Browse files Browse the repository at this point in the history
  • Loading branch information
tlunet committed Jun 21, 2024
1 parent 819cc5f commit 5542b8e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies:
...
- pip
- pip:
qmat
- qmat
```
## Install from source
Expand Down
21 changes: 20 additions & 1 deletion qmat/lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class LagrangeApproximation(object):
- 'MAX' : scaling based on the maximum weight value.
The default is 'MAX'.
fValues : list, tuple or np.1darray
Function values to be used when evaluating the LagrangeApproximation as a function
Attributes
----------
Expand All @@ -118,7 +120,7 @@ class LagrangeApproximation(object):
"Barycentric Lagrange interpolation." SIAM review, 46(3), 501-517.
"""

def __init__(self, points, weightComputation='AUTO', scaleRef='MAX'):
def __init__(self, points, weightComputation='AUTO', scaleRef='MAX', fValues=None):
points = np.asarray(points).ravel()
assert np.unique(points).size == points.size, "distinct interpolation points are required"

Expand Down Expand Up @@ -178,6 +180,23 @@ def chebfun(diffs):
self.weights = weights
self.weightComputation = weightComputation

# Store function values if provided
if fValues is not None:
fValues = np.asarray(fValues)
if fValues.shape != points.shape:
raise ValueError(f'fValues {fValues.shape} has not the correct shape: {points.shape}')
self.fValues = fValues


def __call__(self, t, fValues=None):
if fValues is None: fValues=self.fValues
assert fValues is not None, "cannot evaluate polynomial without fValues"
t = np.asarray(t)
fValues = np.asarray(fValues)
values = self.getInterpolationMatrix(t.ravel()).dot(fValues)
values.shape = t.shape
return values


@property
def n(self):
Expand Down
25 changes: 22 additions & 3 deletions tests/test_2_lagrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,28 @@ def testInterpolation(nNodes, weightComputation):
P = approx.getInterpolationMatrix(times)

polyCoeffs = np.random.rand(nNodes)
polyNodes = np.polyval(polyCoeffs, nodes)
polyTimes = np.polyval(polyCoeffs, times)
assert np.allclose(polyTimes, P @ polyNodes)
polyValues = np.polyval(polyCoeffs, nodes)
refEvals = np.polyval(polyCoeffs, times)
assert np.allclose(refEvals, P @ polyValues)


@pytest.mark.parametrize("weightComputation", ["AUTO", "FAST", "STABLE", "CHEBFUN"])
@pytest.mark.parametrize("nNodes", nNodeTests)
def testEvaluation(nNodes, weightComputation):
nodes = np.sort(np.random.rand(nNodes))
times = np.random.rand(nNodes*2)
polyCoeffs = np.random.rand(nNodes)
polyValues = np.polyval(polyCoeffs, nodes)

approx = LagrangeApproximation(nodes, weightComputation=weightComputation, fValues=polyValues)
P = approx.getInterpolationMatrix(times)
refEvals = P @ polyValues

polyEvals = approx(t=times, fValues=polyValues)
assert np.allclose(polyEvals, refEvals)

polyEvals = approx(t=times)
assert np.allclose(polyEvals, refEvals)


@pytest.mark.parametrize("numQuad", ["LEGENDRE_NUMPY", "LEGENDRE_SCIPY", "FEJER"])
Expand Down

0 comments on commit 5542b8e

Please sign in to comment.