diff --git a/movement_primitives/utils.py b/movement_primitives/utils.py index 04ddec2..469ec53 100644 --- a/movement_primitives/utils.py +++ b/movement_primitives/utils.py @@ -2,10 +2,34 @@ def ensure_1d_array(value, n_dims, var_name): - """Process scalar or array-like input to ensure it is a 1D numpy array of the correct shape.""" - value = np.atleast_1d(value).astype(float).flatten() - if value.shape[0] == 1: - value = value * np.ones(n_dims) - elif value.shape != (n_dims,): - raise ValueError(f"{var_name} has incorrect shape, expected ({n_dims},) got {value.shape}") + """Process scalar or array-like input to ensure it is a 1D numpy array. + + Parameters + ---------- + value : float or array-like, shape (n_dims,) + Argument to be processed. + + n_dims : int + Expected length of the 1d array. + + var_name : str + Name of the variable in case an exception has to be raised. + + Returns + ------- + value : array, shape (n_dims,) + 1D numpy array with dtype float. + + Raises + ------ + ValueError + If the argument is not compatible. + """ + value = np.atleast_1d(value).astype(float) + if value.ndim == 1 and value.shape[0] == 1: + value = np.repeat(value, n_dims) + if value.ndim > 1 or value.shape[0] != n_dims: + raise ValueError( + f"{var_name} has incorrect shape, expected ({n_dims},) " + f"got {value.shape}") return value diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..af776a5 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,30 @@ +import numpy as np +import pytest + +from movement_primitives.utils import ensure_1d_array + + +def test_ensure_1d_array_float(): + a = ensure_1d_array(5.0, 6, "a") + assert a.ndim == 1 + assert a.shape[0] == 6 + + +def test_ensure_1d_array(): + a = ensure_1d_array(np.ones(7), 7, "a") + assert a.ndim == 1 + assert a.shape[0] == 7 + + +def test_ensure_1d_array_wrong_size(): + with pytest.raises( + ValueError, + match=r"a has incorrect shape, expected \(8,\) got \(7,\)"): + ensure_1d_array(np.ones(7), 8, "a") + + +def test_ensure_1d_array_wrong_shape(): + with pytest.raises( + ValueError, + match=r"a has incorrect shape, expected \(7,\) got \(1, 7\)"): + ensure_1d_array(np.ones((1, 7)), 7, "a")