Skip to content

Commit

Permalink
tests for other filters (cheby1, cheby2, ellip, bessel)
Browse files Browse the repository at this point in the history
  • Loading branch information
ravioli1369 committed Jan 27, 2025
1 parent 5e67ff4 commit 9106378
Showing 1 changed file with 157 additions and 129 deletions.
286 changes: 157 additions & 129 deletions tests/transforms/test_iirfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch
from astropy import units as u
from scipy.signal import butter, filtfilt
from scipy.signal import filtfilt, iirfilter
from torch.distributions import Uniform

from ml4gw.constants import MSUN
Expand All @@ -16,6 +16,8 @@

low_cutoff = 100
high_cutoff = 20
filters = ["cheby1", "cheby2", "ellip", "bessel", "butter"]
rprs = [(0.5, None), (None, 20), (0.5, 20), (None, None), (None, None)]


@pytest.fixture(params=[256, 512, 1024, 2048])
Expand All @@ -28,7 +30,7 @@ def order(request):
return request.param


def test_butterworth_synthetic_signal(sample_rate, order):
def test_filters_synthetic_signal(sample_rate, order):
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
tone_freq = 50
noise_amplitude = 0.5
Expand All @@ -37,77 +39,90 @@ def test_butterworth_synthetic_signal(sample_rate, order):
noise = noise_amplitude * np.random.normal(size=t.shape)
combined_signal = signal + noise

slice_length = int(0.1 * sample_rate)

butterworth_low = IIRFilter(
order,
low_cutoff,
btype="low",
analog=False,
fs=sample_rate,
)
butterworth_high = IIRFilter(
order,
high_cutoff,
btype="high",
analog=False,
fs=sample_rate,
)
slice_length = int(0.15 * sample_rate)

for ftype, (rp, rs) in zip(filters, rprs):
b, a = iirfilter(
order,
low_cutoff,
btype="low",
analog=False,
output="ba",
fs=sample_rate,
rp=rp,
rs=rs,
ftype=ftype,
)
scipy_filtered_data_low = filtfilt(b, a, combined_signal)[
slice_length:-slice_length
]

b, a = iirfilter(
order,
high_cutoff,
btype="high",
analog=False,
output="ba",
fs=sample_rate,
rp=rp,
rs=rs,
ftype=ftype,
)
scipy_filtered_data_high = filtfilt(b, a, combined_signal)[
slice_length:-slice_length
]

# test one of these with a tensor input instead of scalar Wn, rs, rps
torch_filtered_data_low = IIRFilter(
order,
torch.tensor(low_cutoff),
btype="low",
analog=False,
fs=sample_rate,
rs=torch.tensor(rs) if rs is not None else None,
rp=torch.tensor(rp) if rp is not None else None,
ftype=ftype,
)(torch.tensor(combined_signal).repeat(10, 1))[
:, slice_length:-slice_length
].numpy()

torch_filtered_data_high = IIRFilter(
order,
high_cutoff,
btype="high",
analog=False,
fs=sample_rate,
rs=rs,
rp=rp,
ftype=ftype,
)(torch.tensor(combined_signal).repeat(10, 1))[
:, slice_length:-slice_length
].numpy()

# test batch processing
for i in range(9):
assert np.allclose(
torch_filtered_data_low[0],
torch_filtered_data_low[i + 1],
atol=float(np.finfo(float).eps),
)
assert np.allclose(
torch_filtered_data_high[0],
torch_filtered_data_high[i + 1],
atol=float(np.finfo(float).eps),
)

b, a = butter(
order,
low_cutoff,
btype="low",
analog=False,
output="ba",
fs=sample_rate,
)
scipy_filtered_data_low = filtfilt(b, a, combined_signal)[
slice_length:-slice_length
]
b, a = butter(
order,
high_cutoff,
btype="high",
analog=False,
output="ba",
fs=sample_rate,
)
scipy_filtered_data_high = filtfilt(b, a, combined_signal)[
slice_length:-slice_length
]

torch_filtered_data_low = butterworth_low(
torch.tensor(combined_signal).repeat(10, 1)
)[:, slice_length:-slice_length].numpy()
torch_filtered_data_high = butterworth_high(
torch.tensor(combined_signal).repeat(10, 1)
)[:, slice_length:-slice_length].numpy()

# test batch processing
for i in range(9):
assert np.allclose(
scipy_filtered_data_low,
torch_filtered_data_low[0],
torch_filtered_data_low[i + 1],
atol=float(np.finfo(float).eps),
atol=2e-1,
)
assert np.allclose(
scipy_filtered_data_high,
torch_filtered_data_high[0],
torch_filtered_data_high[i + 1],
atol=float(np.finfo(float).eps),
atol=2e-1,
)

assert np.allclose(
scipy_filtered_data_low,
torch_filtered_data_low[0],
atol=1e-1,
)
assert np.allclose(
scipy_filtered_data_high,
torch_filtered_data_high[0],
atol=1e-1,
)


N_SAMPLES = 1

Expand Down Expand Up @@ -249,73 +264,86 @@ def test_butterworth_phenom_signal(
hp_lal, _ = lalsimulation.SimInspiralChooseFDWaveform(**params)
hp_lal = hp_lal.data.data.real

slice_length = int(0.1 * sample_rate)

b, a = butter(
order,
low_cutoff,
btype="low",
analog=False,
output="ba",
fs=sample_rate,
)
scipy_filtered_data_low = filtfilt(b, a, hp_lal)[
slice_length:-slice_length
]
b, a = butter(
order,
high_cutoff,
btype="high",
analog=False,
output="ba",
fs=sample_rate,
)
scipy_filtered_data_high = filtfilt(b, a, hp_lal)[
slice_length:-slice_length
]

butterworth_low = IIRFilter(
order,
low_cutoff,
btype="low",
analog=False,
fs=sample_rate,
)
butterworth_high = IIRFilter(
order,
high_cutoff,
btype="high",
analog=False,
fs=sample_rate,
)
slice_length = int(0.15 * sample_rate)

for ftype, (rp, rs) in zip(filters, rprs):
b, a = iirfilter(
order,
low_cutoff,
btype="low",
analog=False,
output="ba",
fs=sample_rate,
rp=rp,
rs=rs,
ftype=ftype,
)

torch_filtered_data_low = butterworth_low(
torch.tensor(hp_lal).repeat(10, 1)
)[:, slice_length:-slice_length].numpy()
torch_filtered_data_high = butterworth_high(
torch.tensor(hp_lal).repeat(10, 1)
)[:, slice_length:-slice_length].numpy()
scipy_filtered_data_low = filtfilt(b, a, hp_lal)[
slice_length:-slice_length
]

b, a = iirfilter(
order,
high_cutoff,
btype="high",
analog=False,
output="ba",
fs=sample_rate,
rp=rp,
rs=rs,
ftype=ftype,
)
scipy_filtered_data_high = filtfilt(b, a, hp_lal)[
slice_length:-slice_length
]

torch_filtered_data_low = IIRFilter(
order,
low_cutoff,
btype="low",
analog=False,
fs=sample_rate,
rs=rs,
rp=rp,
ftype=ftype,
)(torch.tensor(hp_lal).repeat(10, 1))[
:, slice_length:-slice_length
].numpy()

torch_filtered_data_high = IIRFilter(
order,
high_cutoff,
btype="high",
analog=False,
fs=sample_rate,
rs=rs,
rp=rp,
ftype=ftype,
)(torch.tensor(hp_lal).repeat(10, 1))[
:, slice_length:-slice_length
].numpy()

# test batch processing
for i in range(9):
assert np.allclose(
torch_filtered_data_low[0],
torch_filtered_data_low[i + 1],
atol=float(np.finfo(float).eps),
)
assert np.allclose(
torch_filtered_data_high[0],
torch_filtered_data_high[i + 1],
atol=float(np.finfo(float).eps),
)

# test batch processing
for i in range(9):
assert np.allclose(
torch_filtered_data_low[0],
torch_filtered_data_low[i + 1],
atol=float(np.finfo(float).eps),
1e21 * scipy_filtered_data_low,
1e21 * torch_filtered_data_low[0],
atol=2e-3,
)
assert np.allclose(
torch_filtered_data_high[0],
torch_filtered_data_high[i + 1],
atol=float(np.finfo(float).eps),
1e21 * scipy_filtered_data_high,
1e21 * torch_filtered_data_high[0],
atol=2e-3,
)

assert np.allclose(
1e21 * scipy_filtered_data_low,
1e21 * torch_filtered_data_low[0],
atol=1e-6,
)
assert np.allclose(
1e21 * scipy_filtered_data_high,
1e21 * torch_filtered_data_high[0],
atol=1e-6,
)

0 comments on commit 9106378

Please sign in to comment.