diff --git a/ml4gw/transforms/__init__.py b/ml4gw/transforms/__init__.py index 74a7eb8..e9cf480 100644 --- a/ml4gw/transforms/__init__.py +++ b/ml4gw/transforms/__init__.py @@ -1,3 +1,4 @@ +from .iirfilter import IIRFilter from .pearson import ShiftedPearsonCorrelation from .qtransform import QScan, SingleQTransform from .scaler import ChannelWiseScaler diff --git a/ml4gw/waveforms/generator.py b/ml4gw/waveforms/generator.py index bcf4197..d9e27c1 100644 --- a/ml4gw/waveforms/generator.py +++ b/ml4gw/waveforms/generator.py @@ -1,11 +1,13 @@ import math from typing import Callable, Dict, Tuple +import numpy as np import torch from jaxtyping import Float from torch import Tensor from ml4gw.constants import MSUN +from ml4gw.transforms import IIRFilter from ml4gw.types import BatchTensor from ml4gw.waveforms.cbc import utils @@ -68,10 +70,7 @@ def __init__( self.right_pad = right_pad self.f_ref = f_ref - def get_frequencies(self, df: float): - """Get the frequencies from 0 to nyquist for corresponding df""" - num_freqs = int(self.nyquist / df) + 1 - return torch.linspace(0, self.nyquist, num_freqs) + self.highpass = self.build_highpass_filter() @property def delta_t(self): @@ -90,6 +89,30 @@ def size(self): def delta_f(self): return 1 / self.duration + def build_highpass_filter(self): + """ + Builds highpass filter object. + + See https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/conditioning_subroutines.py?ref_type=heads#L10 # noqa + """ + order = 8.0 + w1 = np.tan(np.pi * (self.f_min) / self.sample_rate) + attenuation = 0.99 + wc = w1 * (1.0 / attenuation**0.5 - 1) ** (1.0 / (2.0 * order)) + fc = self.sample_rate * np.arctan(wc) / np.pi + + return IIRFilter( + order, + fc / (self.sample_rate / 2), + btype="highpass", + ftype="butterworth", + ) + + def get_frequencies(self, df: float): + """Get the frequencies from 0 to nyquist for corresponding df""" + num_freqs = int(self.nyquist / df) + 1 + return torch.linspace(0, self.nyquist, num_freqs) + def generate_conditioned_fd_waveform( self, **parameters: dict[str, BatchTensor] ): @@ -275,4 +298,9 @@ def forward( hc = torch.nn.functional.pad(hc, (pad, 0)) hp = torch.nn.functional.pad(hp, (pad, 0)) + # finally, highpass the waveforms, + # going to double precision + hp = self.highpass(hp.double()) + hc = self.highpass(hc.double()) + return hc, hp diff --git a/tests/waveforms/test_generator.py b/tests/waveforms/test_generator.py index 6ecccd8..5ff53ab 100644 --- a/tests/waveforms/test_generator.py +++ b/tests/waveforms/test_generator.py @@ -12,17 +12,7 @@ from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator -@pytest.fixture(params=[10, 100, 1000]) -def n_samples(request): - return request.param - - -@pytest.fixture(params=[1, 2, 10]) -def duration(request): - return request.param - - -@pytest.fixture(params=[1024, 2048, 4096]) +@pytest.fixture(params=[2048]) def sample_rate(request): return request.param @@ -84,7 +74,7 @@ def test_cbc_waveform_generator( theta_jn, sample_rate, ): - duration = 20 + duration = 10 f_min = 20 f_ref = 40 right_pad = 0.5 @@ -163,31 +153,23 @@ def test_cbc_waveform_generator( gwsignal_params, gwsignal_generator ) - hp_ml4gw_highpassed = high_pass_time_series( - hp_ml4gw[i].detach().numpy(), 1 / sample_rate, f_min, 0.99, 8.0 - ) - hc_ml4gw_highpassed = high_pass_time_series( - hc_ml4gw[i].detach().numpy(), 1 / sample_rate, f_min, 0.99, 8.0 - ) + hp = hp_ml4gw[i].detach().numpy() + hc = hc_ml4gw[i].detach().numpy() # now align the gwsignal and ml4gw waveforms so we can compoare - ml4gw_times = np.arange( - 0, len(hp_ml4gw_highpassed) / sample_rate, 1 / sample_rate - ) + ml4gw_times = np.arange(0, len(hp) / sample_rate, 1 / sample_rate) ml4gw_times -= duration - right_pad - hp_ml4gw_times = ( - ml4gw_times - ml4gw_times[np.argmax(hp_ml4gw_highpassed)] - ) + hp_ml4gw_times = ml4gw_times - ml4gw_times[np.argmax(hp)] max_time = min(hp_ml4gw_times[-1], hp_gwsignal.times.value[-1]) min_time = max(hp_ml4gw_times[0], hp_gwsignal.times.value[0]) - mask = hp_gwsignal.times.value < max_time - mask &= hp_gwsignal.times.value > min_time + mask = hp_gwsignal.times.value <= max_time + mask &= hp_gwsignal.times.value >= min_time - ml4gw_mask = hp_ml4gw_times < max_time - ml4gw_mask &= hp_ml4gw_times > min_time + ml4gw_mask = hp_ml4gw_times <= max_time + ml4gw_mask &= hp_ml4gw_times >= min_time # TODO: track this down @@ -199,30 +181,28 @@ def test_cbc_waveform_generator( close_hp = np.allclose( hp_gwsignal.value[mask], - hp_ml4gw_highpassed[ml4gw_mask], + hp[ml4gw_mask], atol=5e-23, - rtol=0.01, + rtol=0.05, ) - close_hp = close_hp or np.allclose( + assert close_hp or np.allclose( hp_gwsignal.value[mask][:-1], - hp_ml4gw_highpassed[ml4gw_mask][1:], + hp[ml4gw_mask][1:], atol=5e-23, - rtol=0.01, + rtol=0.05, ) - assert close_hp close_hc = np.allclose( hc_gwsignal.value[mask], - hc_ml4gw_highpassed[ml4gw_mask], + hc[ml4gw_mask], atol=5e-23, - rtol=0.01, + rtol=0.05, ) - close_hc = close_hc or np.allclose( + assert close_hc or np.allclose( hc_gwsignal.value[mask][:-1], - hc_ml4gw_highpassed[ml4gw_mask][1:], + hc[ml4gw_mask][1:], atol=5e-23, - rtol=0.01, + rtol=0.05, ) - assert close_hc