diff --git a/src/jacobi/_jacobi.py b/src/jacobi/_jacobi.py index 4f06e07..2658bc0 100644 --- a/src/jacobi/_jacobi.py +++ b/src/jacobi/_jacobi.py @@ -64,7 +64,7 @@ def jacobi( """ if diagonal: # TODO maybe solve this without introducing a wrapper function - return jacobi( + j, je = jacobi( lambda dx, x, *args: fn(x + dx, *args), 0, x, @@ -76,6 +76,10 @@ def jacobi( step=step, diagnostic=diagnostic, ) + if mask is not None: + j[~mask] = 0.0 + je[~mask] = 0.0 + return j, je if maxiter <= 0: raise ValueError("maxiter must be > 0") diff --git a/src/jacobi/_propagate.py b/src/jacobi/_propagate.py index aac8851..e72cc71 100644 --- a/src/jacobi/_propagate.py +++ b/src/jacobi/_propagate.py @@ -178,16 +178,33 @@ def _propagate_independent( ): ycov: Union[float, np.ndarray] = 0 + mask = kwargs.get("mask", None) + mask_parts = [] + if mask is None: + kwargs2 = kwargs + else: + kwargs2 = kwargs.copy() + for i, x in enumerate(x_parts): + # this fails if mask is not indexable, but mask is always an array + if np.shape(x) == np.shape(mask[i]): + mask_parts.append(mask[i]) + elif np.shape(x) == np.shape(mask): + mask_parts.append(mask) + else: + raise ValueError("mask shapes do not match arguments") + for i, x in enumerate(x_parts): - rest = x_parts[:i] + x_parts[i + 1 :] def wrapped(x): - args = rest[:i] + [x] + rest[i:] + args = x_parts[:i] + [x] + x_parts[i + 1 :] return fn(*args) xcov = xcov_parts[i] - yc = _propagate(wrapped, y, x, xcov)[1] + if mask_parts: + kwargs2["mask"] = mask_parts[i] + + yc = _propagate(wrapped, y, x, xcov, **kwargs2)[1] if np.ndim(ycov) == 2 and yc.ndim == 1: for i, yci in enumerate(yc): ycov[i, i] += yci # type:ignore diff --git a/test/test_propagate.py b/test/test_propagate.py index 4ca395f..09ee47f 100644 --- a/test/test_propagate.py +++ b/test/test_propagate.py @@ -2,6 +2,7 @@ from numpy.testing import assert_allclose from jacobi import propagate, jacobi import pytest +from numpy.testing import assert_equal def test_00(): @@ -218,7 +219,7 @@ def fn(x): assert_allclose(ycov, ycov_ref) -def test_on_nan(): +def test_on_nan_1(): def fn(x): return x**2 + 1 @@ -244,3 +245,66 @@ def fn(x): y2, ycov2 = propagate(fn, x, xcov) assert_allclose(y2, y_ref) assert_allclose(ycov2, ycov_ref) + + +def test_on_nan_2(): + nan = np.nan + a = np.array([4303.16536081, nan, 2586.42395464, nan, 2010.31141544, nan, nan, nan]) + a_var = np.array( + [7.89977628e04, nan, 1.87676043e22, nan, 8.70294972e04, nan, nan, nan] + ) + b = np.array([0.48358779, 0.0, 0.29371395, 0.0, 0.29838083, 0.58419942, 0.0, 0.0]) + b_var = np.array( + [ + 2.31907643e-05, + 0.00000000e00, + 2.17812131e-05, + 0.00000000e00, + 2.82526004e-05, + 1.66067899e-03, + 0.00000000e00, + 0.00000000e00, + ] + ) + + def f(a, b): + return a * b + + c, c_var = propagate(f, a, a_var, b, b_var, diagonal=True) + + mask = np.isnan(a) | np.isnan(b) + mask_var = mask | np.isnan(a_var) | np.isnan(b_var) + assert_equal(np.isnan(c), mask) + assert_equal(np.isnan(c_var), mask_var) + + +def test_mask_on_binary_function_1(): + a = np.array([1.0, 2.0]) + a_var = 0.01 * a + b = np.array([3.0, 4.0]) + b_var = 0.01 * b + + def f(a, b): + return a * b + + mask = [False, True] + c, c_var = propagate(f, a, a_var, b, b_var, mask=mask) + + assert c_var[0] == 0 + assert c_var[1] > 0 + + +def test_mask_on_binary_function_2(): + a = np.array([1.0, 2.0]) + a_var = 0.01 * a + b = np.array([3.0, 4.0, 5.0]) + b_var = 0.01 * b + + def f(a, b): + return np.outer(a, b).ravel() + + mask = [[False, True], [True, False, True]] + c1, c1_var = propagate(f, a, a_var, b, b_var, mask=mask) + c2, c2_var = propagate(f, a, a_var, b, b_var) + + assert np.sum(np.diag(c2_var) > np.diag(c1_var)) > 0