diff --git a/tests/transforms/test_iirfilter.py b/tests/transforms/test_iirfilter.py index cf280be..70f6fc4 100644 --- a/tests/transforms/test_iirfilter.py +++ b/tests/transforms/test_iirfilter.py @@ -14,10 +14,27 @@ chirp_mass_and_mass_ratio_to_components, ) -low_cutoff = 100 -high_cutoff = 20 -filters = ["cheby1", "cheby2", "ellip", "bessel", "butter"] -rprs = [(0.5, None), (None, 20), (0.5, 20), (None, None), (None, None)] + +@pytest.fixture +def low_cutoff(): + return 100 + + +@pytest.fixture +def high_cutoff(): + return 100 + + +@pytest.fixture( + params=[(0.5, None), (None, 20), (0.5, 20), (None, None), (None, None)] +) +def rpr(request): + return request.param + + +@pytest.fixture(params=["cheby1", "cheby2", "ellip", "bessel", "butter"]) +def filter(request): + return request.param @pytest.fixture(params=[256, 512, 1024, 2048]) @@ -30,7 +47,9 @@ def order(request): return request.param -def test_filters_synthetic_signal(sample_rate, order): +def test_filters_synthetic_signal( + sample_rate, order, filter, low_cutoff, high_cutoff, rprs +): t = np.linspace(0, 1.0, sample_rate, endpoint=False) tone_freq = 50 noise_amplitude = 0.5 @@ -41,88 +60,88 @@ def test_filters_synthetic_signal(sample_rate, order): slice_length = int(0.15 * sample_rate) - for ftype, (rp, rs) in zip(filters, rprs): - b, a = iirfilter( - order, - low_cutoff, - btype="low", - analog=False, - output="ba", - fs=sample_rate, - rp=rp, - rs=rs, - ftype=ftype, - ) - scipy_filtered_data_low = filtfilt(b, a, combined_signal)[ - slice_length:-slice_length - ] - - b, a = iirfilter( - order, - high_cutoff, - btype="high", - analog=False, - output="ba", - fs=sample_rate, - rp=rp, - rs=rs, - ftype=ftype, - ) - scipy_filtered_data_high = filtfilt(b, a, combined_signal)[ - slice_length:-slice_length - ] - - # test one of these with a tensor input instead of scalar Wn, rs, rps - torch_filtered_data_low = IIRFilter( - order, - torch.tensor(low_cutoff), - btype="low", - analog=False, - fs=sample_rate, - rs=torch.tensor(rs) if rs is not None else None, - rp=torch.tensor(rp) if rp is not None else None, - ftype=ftype, - )(torch.tensor(combined_signal).repeat(10, 1))[ - :, slice_length:-slice_length - ].numpy() - - torch_filtered_data_high = IIRFilter( - order, - high_cutoff, - btype="high", - analog=False, - fs=sample_rate, - rs=rs, - rp=rp, - ftype=ftype, - )(torch.tensor(combined_signal).repeat(10, 1))[ - :, slice_length:-slice_length - ].numpy() - - # test batch processing - for i in range(9): - assert np.allclose( - torch_filtered_data_low[0], - torch_filtered_data_low[i + 1], - atol=float(np.finfo(float).eps), - ) - assert np.allclose( - torch_filtered_data_high[0], - torch_filtered_data_high[i + 1], - atol=float(np.finfo(float).eps), - ) - + rp, rs = rprs + b, a = iirfilter( + order, + low_cutoff, + btype="low", + analog=False, + output="ba", + fs=sample_rate, + rp=rp, + rs=rs, + ftype=filter, + ) + scipy_filtered_data_low = filtfilt(b, a, combined_signal)[ + slice_length:-slice_length + ] + + b, a = iirfilter( + order, + high_cutoff, + btype="high", + analog=False, + output="ba", + fs=sample_rate, + rp=rp, + rs=rs, + ftype=filter, + ) + scipy_filtered_data_high = filtfilt(b, a, combined_signal)[ + slice_length:-slice_length + ] + + # test one of these with a tensor input instead of scalar Wn, rs, rps + torch_filtered_data_low = IIRFilter( + order, + torch.tensor(low_cutoff), + btype="low", + analog=False, + fs=sample_rate, + rs=torch.tensor(rs) if rs is not None else None, + rp=torch.tensor(rp) if rp is not None else None, + ftype=filter, + )(torch.tensor(combined_signal).repeat(10, 1))[ + :, slice_length:-slice_length + ].numpy() + + torch_filtered_data_high = IIRFilter( + order, + high_cutoff, + btype="high", + analog=False, + fs=sample_rate, + rs=rs, + rp=rp, + ftype=filter, + )(torch.tensor(combined_signal).repeat(10, 1))[ + :, slice_length:-slice_length + ].numpy() + + # test batch processing + for i in range(9): assert np.allclose( - scipy_filtered_data_low, torch_filtered_data_low[0], - atol=2e-1, + torch_filtered_data_low[i + 1], + atol=float(np.finfo(float).eps), ) assert np.allclose( - scipy_filtered_data_high, torch_filtered_data_high[0], - atol=2e-1, + torch_filtered_data_high[i + 1], + atol=float(np.finfo(float).eps), ) + assert np.allclose( + scipy_filtered_data_low, + torch_filtered_data_low[0], + atol=2e-1, + ) + assert np.allclose( + scipy_filtered_data_high, + torch_filtered_data_high[0], + atol=2e-1, + ) + N_SAMPLES = 1 @@ -201,6 +220,8 @@ def f_ref(request): def test_filters_phenom_signal( sample_rate, order, + filter, + rprs, chirp_mass, mass_ratio, distance, @@ -266,84 +287,85 @@ def test_filters_phenom_signal( slice_length = int(0.15 * sample_rate) - for ftype, (rp, rs) in zip(filters, rprs): - b, a = iirfilter( - order, - low_cutoff, - btype="low", - analog=False, - output="ba", - fs=sample_rate, - rp=rp, - rs=rs, - ftype=ftype, - ) - - scipy_filtered_data_low = filtfilt(b, a, hp_lal)[ - slice_length:-slice_length - ] - - b, a = iirfilter( - order, - high_cutoff, - btype="high", - analog=False, - output="ba", - fs=sample_rate, - rp=rp, - rs=rs, - ftype=ftype, - ) - scipy_filtered_data_high = filtfilt(b, a, hp_lal)[ - slice_length:-slice_length - ] - - torch_filtered_data_low = IIRFilter( - order, - low_cutoff, - btype="low", - analog=False, - fs=sample_rate, - rs=rs, - rp=rp, - ftype=ftype, - )(torch.tensor(hp_lal).repeat(10, 1))[ - :, slice_length:-slice_length - ].numpy() - - torch_filtered_data_high = IIRFilter( - order, - high_cutoff, - btype="high", - analog=False, - fs=sample_rate, - rs=rs, - rp=rp, - ftype=ftype, - )(torch.tensor(hp_lal).repeat(10, 1))[ - :, slice_length:-slice_length - ].numpy() - - # test batch processing - for i in range(9): - assert np.allclose( - torch_filtered_data_low[0], - torch_filtered_data_low[i + 1], - atol=float(np.finfo(float).eps), - ) - assert np.allclose( - torch_filtered_data_high[0], - torch_filtered_data_high[i + 1], - atol=float(np.finfo(float).eps), - ) + rp, rs = rprs + + b, a = iirfilter( + order, + low_cutoff, + btype="low", + analog=False, + output="ba", + fs=sample_rate, + rp=rp, + rs=rs, + ftype=filter, + ) + scipy_filtered_data_low = filtfilt(b, a, hp_lal)[ + slice_length:-slice_length + ] + + b, a = iirfilter( + order, + high_cutoff, + btype="high", + analog=False, + output="ba", + fs=sample_rate, + rp=rp, + rs=rs, + ftype=filter, + ) + scipy_filtered_data_high = filtfilt(b, a, hp_lal)[ + slice_length:-slice_length + ] + + torch_filtered_data_low = IIRFilter( + order, + low_cutoff, + btype="low", + analog=False, + fs=sample_rate, + rs=rs, + rp=rp, + ftype=filter, + )(torch.tensor(hp_lal).repeat(10, 1))[ + :, slice_length:-slice_length + ].numpy() + + torch_filtered_data_high = IIRFilter( + order, + high_cutoff, + btype="high", + analog=False, + fs=sample_rate, + rs=rs, + rp=rp, + ftype=filter, + )(torch.tensor(hp_lal).repeat(10, 1))[ + :, slice_length:-slice_length + ].numpy() + + # test batch processing + for i in range(9): assert np.allclose( - 1e21 * scipy_filtered_data_low, - 1e21 * torch_filtered_data_low[0], - atol=7e-3, + torch_filtered_data_low[0], + torch_filtered_data_low[i + 1], + atol=float(np.finfo(float).eps), ) assert np.allclose( - 1e21 * scipy_filtered_data_high, - 1e21 * torch_filtered_data_high[0], - atol=7e-3, + torch_filtered_data_high[0], + torch_filtered_data_high[i + 1], + atol=float(np.finfo(float).eps), ) + + assert np.allclose( + 1e21 * scipy_filtered_data_low, + 1e21 * torch_filtered_data_low[0], + atol=7e-3, + ) + assert np.allclose( + 1e21 * scipy_filtered_data_high, + 1e21 * torch_filtered_data_high[0], + atol=7e-3, + )