diff --git a/tests/unit_tests/config/test_transfer_functions.py b/tests/unit_tests/config/test_transfer_functions.py index 23c0acf0108..e2735e3f6f9 100644 --- a/tests/unit_tests/config/test_transfer_functions.py +++ b/tests/unit_tests/config/test_transfer_functions.py @@ -43,15 +43,15 @@ def test_that_truncated_normal_is_monotonic(x1, x2, arg): result1 = TransferFunction.trans_truncated_normal(x1, arg) result2 = TransferFunction.trans_truncated_normal(x2, arg) - if x1 < x2: + if np.isclose(x1, x2): + assert np.isclose(result1, result2) + elif x1 < x2: # Results should be different unless clamped assert ( result1 < result2 or (result1 == arg[2] and result2 == arg[2]) or (result1 == arg[3] and result2 == arg[3]) ) - elif x1 == x2: - assert result1 == result2 @given(valid_params())