Skip to content

Commit

Permalink
Fixing NaN issue in propagate when more than one argument is passed (#32
Browse files Browse the repository at this point in the history
)
  • Loading branch information
HDembinski authored Dec 8, 2022
1 parent 1fdbe48 commit b889822
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/jacobi/_jacobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def jacobi(
"""
if diagonal:
# TODO maybe solve this without introducing a wrapper function
return jacobi(
j, je = jacobi(
lambda dx, x, *args: fn(x + dx, *args),
0,
x,
Expand All @@ -76,6 +76,10 @@ def jacobi(
step=step,
diagnostic=diagnostic,
)
if mask is not None:
j[~mask] = 0.0
je[~mask] = 0.0
return j, je

if maxiter <= 0:
raise ValueError("maxiter must be > 0")
Expand Down
23 changes: 20 additions & 3 deletions src/jacobi/_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,33 @@ def _propagate_independent(
):
ycov: Union[float, np.ndarray] = 0

mask = kwargs.get("mask", None)
mask_parts = []
if mask is None:
kwargs2 = kwargs
else:
kwargs2 = kwargs.copy()
for i, x in enumerate(x_parts):
# this fails if mask is not indexable, but mask is always an array
if np.shape(x) == np.shape(mask[i]):
mask_parts.append(mask[i])
elif np.shape(x) == np.shape(mask):
mask_parts.append(mask)
else:
raise ValueError("mask shapes do not match arguments")

for i, x in enumerate(x_parts):
rest = x_parts[:i] + x_parts[i + 1 :]

def wrapped(x):
args = rest[:i] + [x] + rest[i:]
args = x_parts[:i] + [x] + x_parts[i + 1 :]
return fn(*args)

xcov = xcov_parts[i]

yc = _propagate(wrapped, y, x, xcov)[1]
if mask_parts:
kwargs2["mask"] = mask_parts[i]

yc = _propagate(wrapped, y, x, xcov, **kwargs2)[1]
if np.ndim(ycov) == 2 and yc.ndim == 1:
for i, yci in enumerate(yc):
ycov[i, i] += yci # type:ignore
Expand Down
66 changes: 65 additions & 1 deletion test/test_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numpy.testing import assert_allclose
from jacobi import propagate, jacobi
import pytest
from numpy.testing import assert_equal


def test_00():
Expand Down Expand Up @@ -218,7 +219,7 @@ def fn(x):
assert_allclose(ycov, ycov_ref)


def test_on_nan():
def test_on_nan_1():
def fn(x):
return x**2 + 1

Expand All @@ -244,3 +245,66 @@ def fn(x):
y2, ycov2 = propagate(fn, x, xcov)
assert_allclose(y2, y_ref)
assert_allclose(ycov2, ycov_ref)


def test_on_nan_2():
nan = np.nan
a = np.array([4303.16536081, nan, 2586.42395464, nan, 2010.31141544, nan, nan, nan])
a_var = np.array(
[7.89977628e04, nan, 1.87676043e22, nan, 8.70294972e04, nan, nan, nan]
)
b = np.array([0.48358779, 0.0, 0.29371395, 0.0, 0.29838083, 0.58419942, 0.0, 0.0])
b_var = np.array(
[
2.31907643e-05,
0.00000000e00,
2.17812131e-05,
0.00000000e00,
2.82526004e-05,
1.66067899e-03,
0.00000000e00,
0.00000000e00,
]
)

def f(a, b):
return a * b

c, c_var = propagate(f, a, a_var, b, b_var, diagonal=True)

mask = np.isnan(a) | np.isnan(b)
mask_var = mask | np.isnan(a_var) | np.isnan(b_var)
assert_equal(np.isnan(c), mask)
assert_equal(np.isnan(c_var), mask_var)


def test_mask_on_binary_function_1():
a = np.array([1.0, 2.0])
a_var = 0.01 * a
b = np.array([3.0, 4.0])
b_var = 0.01 * b

def f(a, b):
return a * b

mask = [False, True]
c, c_var = propagate(f, a, a_var, b, b_var, mask=mask)

assert c_var[0] == 0
assert c_var[1] > 0


def test_mask_on_binary_function_2():
a = np.array([1.0, 2.0])
a_var = 0.01 * a
b = np.array([3.0, 4.0, 5.0])
b_var = 0.01 * b

def f(a, b):
return np.outer(a, b).ravel()

mask = [[False, True], [True, False, True]]
c1, c1_var = propagate(f, a, a_var, b, b_var, mask=mask)
c2, c2_var = propagate(f, a, a_var, b, b_var)

assert np.sum(np.diag(c2_var) > np.diag(c1_var)) > 0

0 comments on commit b889822

Please sign in to comment.