From 8cde00d9de5469d5b2b6662111ac7a2f609a2246 Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Tue, 4 Jun 2024 20:45:24 -0700 Subject: [PATCH 1/5] ENH: _restore_axes accomodates for negative axis arguments --- derivative/differentiation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/derivative/differentiation.py b/derivative/differentiation.py index e633c68..dbf8657 100644 --- a/derivative/differentiation.py +++ b/derivative/differentiation.py @@ -248,7 +248,10 @@ def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArra return dX.flatten() else: # order of operations coupled with _align_axes - extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis) + orig_diff_axis = range(len(orig_shape))[axis] # to handle negative axis args + extra_dims = tuple( + length for ax, length in enumerate(orig_shape) if ax != orig_diff_axis + ) moved_shape = (orig_shape[axis],) + extra_dims dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis) return dX From 2bfeb89b41b744e98565f4fb12aa8f6c028b505e Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Tue, 4 Jun 2024 22:08:21 -0700 Subject: [PATCH 2/5] TST: added a test for negative axis arguments --- tests/test_interface.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_interface.py b/tests/test_interface.py index 49a88fc..b70c7f8 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -98,4 +98,12 @@ def test_hyperparam_entrypoint(): func = utils._load_hyperparam_func("kalman.default") expected = 1 result = func(None, None) - assert result == expected \ No newline at end of file + assert result == expected + + +def test_negative_axis(): + t = np.arange(3) + x = np.ones((2, 3, 2)) + axis = -2 + dx = dxdt(x, t, kind='finite_difference', axis=axis, k=1) + assert x.shape == dx.shape \ No newline at end of file From b3002b69e139ee798ee0adbd4adf48862e8c80f1 Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Tue, 18 Jun 2024 12:04:05 -0700 Subject: [PATCH 3/5] TST: test makes sure that negative axis arguments differentiates the correct axis --- tests/test_interface.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_interface.py b/tests/test_interface.py index b70c7f8..9af7c00 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -103,7 +103,10 @@ def test_hyperparam_entrypoint(): def test_negative_axis(): t = np.arange(3) - x = np.ones((2, 3, 2)) + x = np.random.random(size=(2, 3, 2)) + x[1, :, 1] = 1 axis = -2 + expected = np.zeros(3) dx = dxdt(x, t, kind='finite_difference', axis=axis, k=1) - assert x.shape == dx.shape \ No newline at end of file + assert x.shape == dx.shape + np.testing.assert_array_almost_equal(dx[1, :, 1], expected) \ No newline at end of file From 6f2053f988e4b945d671285f6f9aa7b1d652e3ad Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Tue, 18 Jun 2024 12:04:52 -0700 Subject: [PATCH 4/5] BLD: added importlib-metadata to pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 238ebce..cab0c7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,9 @@ numpy = "^1.18.3" scipy = "^1.4.1" scikit-learn = "^1" +# third-party access to the functionality of importlib.metadata +importlib-metadata = "^7.1.0" + # docs sphinx = {version = "^5", optional = true} nbsphinx = {version = "^0.6.1", optional = true} From 819f71003d6b033a3f95916fc2722681c619432f Mon Sep 17 00:00:00 2001 From: Yash Bhangale Date: Tue, 18 Jun 2024 14:45:14 -0700 Subject: [PATCH 5/5] CLN: added EOF line and removed comments --- pyproject.toml | 2 -- tests/test_interface.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cab0c7e..5c709d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,6 @@ python = "^3.9" numpy = "^1.18.3" scipy = "^1.4.1" scikit-learn = "^1" - -# third-party access to the functionality of importlib.metadata importlib-metadata = "^7.1.0" # docs diff --git a/tests/test_interface.py b/tests/test_interface.py index 9af7c00..5c75a80 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -109,4 +109,4 @@ def test_negative_axis(): expected = np.zeros(3) dx = dxdt(x, t, kind='finite_difference', axis=axis, k=1) assert x.shape == dx.shape - np.testing.assert_array_almost_equal(dx[1, :, 1], expected) \ No newline at end of file + np.testing.assert_array_almost_equal(dx[1, :, 1], expected)