diff --git a/derivative/differentiation.py b/derivative/differentiation.py index e633c68..dbf8657 100644 --- a/derivative/differentiation.py +++ b/derivative/differentiation.py @@ -248,7 +248,10 @@ def _restore_axes(dX: NDArray, axis: int, orig_shape: tuple[int, ...]) -> NDArra return dX.flatten() else: # order of operations coupled with _align_axes - extra_dims = tuple(length for ax, length in enumerate(orig_shape) if ax != axis) + orig_diff_axis = range(len(orig_shape))[axis] # to handle negative axis args + extra_dims = tuple( + length for ax, length in enumerate(orig_shape) if ax != orig_diff_axis + ) moved_shape = (orig_shape[axis],) + extra_dims dX = np.moveaxis(dX.T.reshape((moved_shape)), 0, axis) return dX diff --git a/pyproject.toml b/pyproject.toml index 238ebce..5c709d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ python = "^3.9" numpy = "^1.18.3" scipy = "^1.4.1" scikit-learn = "^1" +importlib-metadata = "^7.1.0" # docs sphinx = {version = "^5", optional = true} diff --git a/tests/test_interface.py b/tests/test_interface.py index 49a88fc..5c75a80 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -98,4 +98,15 @@ def test_hyperparam_entrypoint(): func = utils._load_hyperparam_func("kalman.default") expected = 1 result = func(None, None) - assert result == expected \ No newline at end of file + assert result == expected + + +def test_negative_axis(): + t = np.arange(3) + x = np.random.random(size=(2, 3, 2)) + x[1, :, 1] = 1 + axis = -2 + expected = np.zeros(3) + dx = dxdt(x, t, kind='finite_difference', axis=axis, k=1) + assert x.shape == dx.shape + np.testing.assert_array_almost_equal(dx[1, :, 1], expected)