Skip to content

Commit

Permalink
Test check_1d_array_length
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFabisch committed Sep 19, 2024
1 parent c471f6b commit 6147307
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
11 changes: 11 additions & 0 deletions test/test_dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ def test_temporal_scaling():
assert np.linalg.norm(Y2 - Y4[::2]) / len(Y2) < 1e-3


def test_dmp_configure_invalid_input():
start_yd = np.array([0.0])
goal_y = np.array([1.0])

sd = DMP(n_dims=1, execution_time=1.0, dt=0.01)
with pytest.raises(
ValueError,
match=r"Expected start_y with 1 element, got 2."):
sd.configure(start_y=np.zeros(2), goal_y=goal_y, start_yd=start_yd)


def test_n_weights():
dmp = DMP(n_dims=5, n_weights_per_dim=9)
assert dmp.n_weights == 45
Expand Down
20 changes: 19 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from movement_primitives.utils import ensure_1d_array
from movement_primitives.utils import ensure_1d_array, check_1d_array_length


def test_ensure_1d_array_float():
Expand All @@ -28,3 +28,21 @@ def test_ensure_1d_array_wrong_shape():
ValueError,
match=r"a has incorrect shape, expected \(7,\) got \(1, 7\)"):
ensure_1d_array(np.ones((1, 7)), 7, "a")


def test_check_1d_array_length_correct():
check_1d_array_length([0, 1], "a", 2)


def test_check_1d_array_length_2_vs_1():
with pytest.raises(
ValueError,
match=r"Expected a with 1 element, got 2."):
check_1d_array_length([0, 2], "a", 1)


def test_check_1d_array_length_1_vs_2():
with pytest.raises(
ValueError,
match=r"Expected b with 2 elements, got 1."):
check_1d_array_length([0], "b", 2)

0 comments on commit 6147307

Please sign in to comment.