Commit: make filters a fixture

EthanMarx committed Feb 6, 2025
1 parent 555f119 commit ddceca3
Showing 1 changed file with 178 additions and 156 deletions.
tests/transforms/test_iirfilter.py
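
This commit replaces the module-level low_cutoff, high_cutoff, filters, and rprs constants, along with the "for ftype, (rp, rs) in zip(filters, rprs)" loop inside each test, with parametrized pytest fixtures, so pytest itself generates one test case per parameter combination. A minimal sketch of that pattern (illustrative only, not code from this repository):

    import pytest


    @pytest.fixture(params=["butter", "cheby1"])
    def ftype(request):
        # pytest runs any test that requests this fixture once per value in params
        return request.param


    def test_uses_fixture(ftype):
        assert ftype in ("butter", "cheby1")

Each parametrized fixture a test requests multiplies the number of generated cases, which is what lets the tests below drop their internal loop over filter types.
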
@@ -14,10 +14,27 @@
     chirp_mass_and_mass_ratio_to_components,
 )
 
-low_cutoff = 100
-high_cutoff = 20
-filters = ["cheby1", "cheby2", "ellip", "bessel", "butter"]
-rprs = [(0.5, None), (None, 20), (0.5, 20), (None, None), (None, None)]
+
+@pytest.fixture
+def low_cutoff():
+    return 100
+
+
+@pytest.fixture
+def high_cutoff():
+    return 100
+
+
+@pytest.fixture(
+    params=[(0.5, None), (None, 20), (0.5, 20), (None, None), (None, None)]
+)
+def rprs(request):
+    return request.param
+
+
+@pytest.fixture(params=["cheby1", "cheby2", "ellip", "bessel", "butter"])
+def filter(request):
+    return request.param
 
 
 @pytest.fixture(params=[256, 512, 1024, 2048])
@@ -30,7 +47,9 @@ def order(request):
     return request.param
 
 
-def test_filters_synthetic_signal(sample_rate, order):
+def test_filters_synthetic_signal(
+    sample_rate, order, filter, low_cutoff, high_cutoff, rprs
+):
     t = np.linspace(0, 1.0, sample_rate, endpoint=False)
     tone_freq = 50
     noise_amplitude = 0.5
@@ -41,88 +60,88 @@ def test_filters_synthetic_signal(sample_rate, order):
 
     slice_length = int(0.15 * sample_rate)
 
-    for ftype, (rp, rs) in zip(filters, rprs):
-        b, a = iirfilter(
-            order,
-            low_cutoff,
-            btype="low",
-            analog=False,
-            output="ba",
-            fs=sample_rate,
-            rp=rp,
-            rs=rs,
-            ftype=ftype,
-        )
-        scipy_filtered_data_low = filtfilt(b, a, combined_signal)[
-            slice_length:-slice_length
-        ]
-
-        b, a = iirfilter(
-            order,
-            high_cutoff,
-            btype="high",
-            analog=False,
-            output="ba",
-            fs=sample_rate,
-            rp=rp,
-            rs=rs,
-            ftype=ftype,
-        )
-        scipy_filtered_data_high = filtfilt(b, a, combined_signal)[
-            slice_length:-slice_length
-        ]
-
-        # test one of these with a tensor input instead of scalar Wn, rs, rps
-        torch_filtered_data_low = IIRFilter(
-            order,
-            torch.tensor(low_cutoff),
-            btype="low",
-            analog=False,
-            fs=sample_rate,
-            rs=torch.tensor(rs) if rs is not None else None,
-            rp=torch.tensor(rp) if rp is not None else None,
-            ftype=ftype,
-        )(torch.tensor(combined_signal).repeat(10, 1))[
-            :, slice_length:-slice_length
-        ].numpy()
-
-        torch_filtered_data_high = IIRFilter(
-            order,
-            high_cutoff,
-            btype="high",
-            analog=False,
-            fs=sample_rate,
-            rs=rs,
-            rp=rp,
-            ftype=ftype,
-        )(torch.tensor(combined_signal).repeat(10, 1))[
-            :, slice_length:-slice_length
-        ].numpy()
-
-        # test batch processing
-        for i in range(9):
-            assert np.allclose(
-                torch_filtered_data_low[0],
-                torch_filtered_data_low[i + 1],
-                atol=float(np.finfo(float).eps),
-            )
-            assert np.allclose(
-                torch_filtered_data_high[0],
-                torch_filtered_data_high[i + 1],
-                atol=float(np.finfo(float).eps),
-            )
-
-        assert np.allclose(
-            scipy_filtered_data_low,
-            torch_filtered_data_low[0],
-            atol=2e-1,
-        )
-        assert np.allclose(
-            scipy_filtered_data_high,
-            torch_filtered_data_high[0],
-            atol=2e-1,
-        )
+    rp, rs = rprs
+    b, a = iirfilter(
+        order,
+        low_cutoff,
+        btype="low",
+        analog=False,
+        output="ba",
+        fs=sample_rate,
+        rp=rp,
+        rs=rs,
+        ftype=filter,
+    )
+    scipy_filtered_data_low = filtfilt(b, a, combined_signal)[
+        slice_length:-slice_length
+    ]
+
+    b, a = iirfilter(
+        order,
+        high_cutoff,
+        btype="high",
+        analog=False,
+        output="ba",
+        fs=sample_rate,
+        rp=rp,
+        rs=rs,
+        ftype=filter,
+    )
+    scipy_filtered_data_high = filtfilt(b, a, combined_signal)[
+        slice_length:-slice_length
+    ]
+
+    # test one of these with a tensor input instead of scalar Wn, rs, rps
+    torch_filtered_data_low = IIRFilter(
+        order,
+        torch.tensor(low_cutoff),
+        btype="low",
+        analog=False,
+        fs=sample_rate,
+        rs=torch.tensor(rs) if rs is not None else None,
+        rp=torch.tensor(rp) if rp is not None else None,
+        ftype=filter,
+    )(torch.tensor(combined_signal).repeat(10, 1))[
+        :, slice_length:-slice_length
+    ].numpy()
+
+    torch_filtered_data_high = IIRFilter(
+        order,
+        high_cutoff,
+        btype="high",
+        analog=False,
+        fs=sample_rate,
+        rs=rs,
+        rp=rp,
+        ftype=filter,
+    )(torch.tensor(combined_signal).repeat(10, 1))[
+        :, slice_length:-slice_length
+    ].numpy()
+
+    # test batch processing
+    for i in range(9):
+        assert np.allclose(
+            torch_filtered_data_low[0],
+            torch_filtered_data_low[i + 1],
+            atol=float(np.finfo(float).eps),
+        )
+        assert np.allclose(
+            torch_filtered_data_high[0],
+            torch_filtered_data_high[i + 1],
+            atol=float(np.finfo(float).eps),
+        )
+
+    assert np.allclose(
+        scipy_filtered_data_low,
+        torch_filtered_data_low[0],
+        atol=2e-1,
+    )
+    assert np.allclose(
+        scipy_filtered_data_high,
+        torch_filtered_data_high[0],
+        atol=2e-1,
+    )


N_SAMPLES = 1

@@ -201,6 +220,8 @@ def f_ref(request):
 def test_filters_phenom_signal(
     sample_rate,
     order,
+    filter,
+    rprs,
     chirp_mass,
     mass_ratio,
     distance,
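
Because filter, rprs, low_cutoff, and high_cutoff are now fixtures, the phenom test picks up the same parametrization simply by listing the fixture names in its signature, and pytest expands the full cross product of fixture parameters into separate test cases. One way to inspect the generated cases without running them (a hypothetical invocation; exact test IDs depend on pytest's id generation):

    pytest tests/transforms/test_iirfilter.py --collect-only -q
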
@@ -266,84 +287,85 @@ def test_filters_phenom_signal(
 
     slice_length = int(0.15 * sample_rate)
 
-    for ftype, (rp, rs) in zip(filters, rprs):
-        b, a = iirfilter(
-            order,
-            low_cutoff,
-            btype="low",
-            analog=False,
-            output="ba",
-            fs=sample_rate,
-            rp=rp,
-            rs=rs,
-            ftype=ftype,
-        )
-
-        scipy_filtered_data_low = filtfilt(b, a, hp_lal)[
-            slice_length:-slice_length
-        ]
-
-        b, a = iirfilter(
-            order,
-            high_cutoff,
-            btype="high",
-            analog=False,
-            output="ba",
-            fs=sample_rate,
-            rp=rp,
-            rs=rs,
-            ftype=ftype,
-        )
-        scipy_filtered_data_high = filtfilt(b, a, hp_lal)[
-            slice_length:-slice_length
-        ]
-
-        torch_filtered_data_low = IIRFilter(
-            order,
-            low_cutoff,
-            btype="low",
-            analog=False,
-            fs=sample_rate,
-            rs=rs,
-            rp=rp,
-            ftype=ftype,
-        )(torch.tensor(hp_lal).repeat(10, 1))[
-            :, slice_length:-slice_length
-        ].numpy()
-
-        torch_filtered_data_high = IIRFilter(
-            order,
-            high_cutoff,
-            btype="high",
-            analog=False,
-            fs=sample_rate,
-            rs=rs,
-            rp=rp,
-            ftype=ftype,
-        )(torch.tensor(hp_lal).repeat(10, 1))[
-            :, slice_length:-slice_length
-        ].numpy()
-
-        # test batch processing
-        for i in range(9):
-            assert np.allclose(
-                torch_filtered_data_low[0],
-                torch_filtered_data_low[i + 1],
-                atol=float(np.finfo(float).eps),
-            )
-            assert np.allclose(
-                torch_filtered_data_high[0],
-                torch_filtered_data_high[i + 1],
-                atol=float(np.finfo(float).eps),
-            )
-
-        assert np.allclose(
-            1e21 * scipy_filtered_data_low,
-            1e21 * torch_filtered_data_low[0],
-            atol=7e-3,
-        )
-        assert np.allclose(
-            1e21 * scipy_filtered_data_high,
-            1e21 * torch_filtered_data_high[0],
-            atol=7e-3,
-        )
+    rp, rs = rprs
+
+    b, a = iirfilter(
+        order,
+        low_cutoff,
+        btype="low",
+        analog=False,
+        output="ba",
+        fs=sample_rate,
+        rp=rp,
+        rs=rs,
+        ftype=filter,
+    )
+
+    scipy_filtered_data_low = filtfilt(b, a, hp_lal)[
+        slice_length:-slice_length
+    ]
+
+    b, a = iirfilter(
+        order,
+        high_cutoff,
+        btype="high",
+        analog=False,
+        output="ba",
+        fs=sample_rate,
+        rp=rp,
+        rs=rs,
+        ftype=filter,
+    )
+    scipy_filtered_data_high = filtfilt(b, a, hp_lal)[
+        slice_length:-slice_length
+    ]
+
+    torch_filtered_data_low = IIRFilter(
+        order,
+        low_cutoff,
+        btype="low",
+        analog=False,
+        fs=sample_rate,
+        rs=rs,
+        rp=rp,
+        ftype=filter,
+    )(torch.tensor(hp_lal).repeat(10, 1))[
+        :, slice_length:-slice_length
+    ].numpy()
+
+    torch_filtered_data_high = IIRFilter(
+        order,
+        high_cutoff,
+        btype="high",
+        analog=False,
+        fs=sample_rate,
+        rs=rs,
+        rp=rp,
+        ftype=filter,
+    )(torch.tensor(hp_lal).repeat(10, 1))[
+        :, slice_length:-slice_length
+    ].numpy()
+
+    # test batch processing
+    for i in range(9):
+        assert np.allclose(
+            torch_filtered_data_low[0],
+            torch_filtered_data_low[i + 1],
+            atol=float(np.finfo(float).eps),
+        )
+        assert np.allclose(
+            torch_filtered_data_high[0],
+            torch_filtered_data_high[i + 1],
+            atol=float(np.finfo(float).eps),
+        )
+
+    assert np.allclose(
+        1e21 * scipy_filtered_data_low,
+        1e21 * torch_filtered_data_low[0],
+        atol=7e-3,
+    )
+    assert np.allclose(
+        1e21 * scipy_filtered_data_high,
+        1e21 * torch_filtered_data_high[0],
+        atol=7e-3,
+    )