diff --git a/tests/conftest.py b/tests/conftest.py index 11df3ba..1b8e59d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ def compare_against_numpy(): of the time """ - def compare(value, expected): + def compare(value, expected, num_bad: int = 0): sigma = 0.01 prob = 0.99999 N = np.prod(expected.shape) @@ -28,7 +28,7 @@ def compare(value, expected): # TODO: eventually we should track down # and address the underlying cause - assert isclose.sum() - np.prod(isclose.shape) <= 1 + assert isclose.sum() - np.prod(isclose.shape) <= num_bad return compare diff --git a/tests/test_spectral.py b/tests/test_spectral.py index 0f834f7..91e6090 100644 --- a/tests/test_spectral.py +++ b/tests/test_spectral.py @@ -111,7 +111,7 @@ def test_fast_spectral_density( # that components higher than the first two are correct torch_result = torch_result[..., 2:] scipy_result = scipy_result[..., 2:] - compare_against_numpy(torch_result, scipy_result) + compare_against_numpy(torch_result, scipy_result, num_bad=1) # make sure we catch any calls with too many dimensions if ndim == 3: @@ -260,7 +260,7 @@ def test_fast_spectral_density_with_y( torch_result = torch_result[..., 2:] scipy_result = scipy_result[..., 2:] - compare_against_numpy(torch_result, scipy_result) + compare_against_numpy(torch_result, scipy_result, num_bad=1) _shape_checks(ndim, y_ndim, x, y, fsd) @@ -322,7 +322,7 @@ def test_spectral_density( window=signal.windows.hann(nperseg, False), average=average, ) - compare_against_numpy(torch_result, scipy_result) + compare_against_numpy(torch_result, scipy_result, num_bad=1) # make sure we catch any calls with too many dimensions if ndim == 3: