From ff058308ea5387b3f995a9b88f1b00035024edc1 Mon Sep 17 00:00:00 2001 From: "ethan.marx" Date: Fri, 31 Jan 2025 10:38:19 -0800 Subject: [PATCH] add tests for waveform generator --- tests/waveforms/test_generator.py | 162 ++++++++++++++++++++++-------- 1 file changed, 122 insertions(+), 40 deletions(-) diff --git a/tests/waveforms/test_generator.py b/tests/waveforms/test_generator.py index 4a78224f..f296bdb1 100644 --- a/tests/waveforms/test_generator.py +++ b/tests/waveforms/test_generator.py @@ -1,8 +1,12 @@ import astropy.units as u -import lal -import lalsimulation +import numpy as np import pytest import torch +from lalsimulation.gwsignal.core.waveform import ( + GenerateTDWaveform, + LALCompactBinaryCoalescenceGenerator, +) +from scipy.signal import butter, sosfiltfilt from ml4gw.waveforms import IMRPhenomD, conversion from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator @@ -23,6 +27,53 @@ def sample_rate(request): return request.param +def high_pass_time_series(time_series, dt, fmin, attenuation, N): + """ + Same as + https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/conditioning_subroutines.py?ref_type=heads#L10 # noqa + except w/o requiring gwpy.TimeSeries objects to be passed + """ + fs = 1.0 / dt # Sampling frequency + a1 = attenuation # Attenuation at the low-freq cut-off + + w1 = np.tan(np.pi * fmin * dt) # Transformed frequency variable at f_min + wc = w1 * (1.0 / a1**0.5 - 1) ** ( + 1.0 / (2.0 * N) + ) # Cut-off freq. from attenuation + fc = fs * np.arctan(wc) / np.pi # For use in butterworth filter + + # Construct the filter and then forward - backward filter the time-series + sos = butter(N, fc, btype="highpass", output="sos", fs=fs) + output = sosfiltfilt(sos, time_series) + return output + + +class GWsignalGenerator(LALCompactBinaryCoalescenceGenerator): + """ + Override gwsignal class to enforce use of `gwsignal` conditioning routines, + and to enforce that the frequency domain version of the approximant is used + for generating the waveform. + """ + + @property + def metadata(self): + metadata = { + "type": "cbc_lalsimulation", + "f_ref_spin": True, + "modes": True, + "polarizations": True, + "implemented_domain": self._implemented_domain, + "generation_domain": self._generation_domain, + "approximant": self._approx_name, + "implementation": "LALSimulation", + "conditioning_routines": "gwsignal", + } + return metadata + + def _update_domains(self): + self._implemented_domain = "freq" + + def test_cbc_waveform_generator( chirp_mass, mass_ratio, @@ -33,11 +84,11 @@ def test_cbc_waveform_generator( theta_jn, sample_rate, ): - sample_rate = 4096 - duration = 1 + sample_rate = 2048 + duration = 10 f_min = 20 f_ref = 40 - right_pad = 0.1 + right_pad = 0.5 generator = TimeDomainCBCWaveformGenerator( approximant=IMRPhenomD(), @@ -57,7 +108,7 @@ def test_cbc_waveform_generator( s2x = torch.zeros_like(chi2) s2y = torch.zeros_like(chi2) s2z = chi2 - parameters = { + ml4gw_parameters = { "chirp_mass": chirp_mass, "mass_ratio": mass_ratio, "mass_1": mass_1, @@ -74,49 +125,80 @@ def test_cbc_waveform_generator( "distance": distance, "inclination": theta_jn, } - hc, hp = generator(**parameters) + hc_ml4gw, hp_ml4gw = generator(**ml4gw_parameters) # now compare each waveform with lalsimulation SimInspiralTD for i in range(len(chirp_mass)): # construct lalinference params - params = dict( - m1=mass_1[i].item() * lal.MSUN_SI, - m2=mass_2[i].item() * lal.MSUN_SI, - S1x=s1x[i].item(), - S1y=s2y[i].item(), - S1z=s1z[i].item(), - S2x=s2x[i].item(), - S2y=s2y[i].item(), - S2z=s1z[i].item(), - distance=(distance[i].item() * u.Mpc).to("m").value, - inclination=theta_jn[i].item(), - phiRef=phase[i].item(), - longAscNodes=0.0, - eccentricity=0.0, - meanPerAno=0.0, - deltaT=1 / sample_rate, - f_min=f_min, - f_ref=f_ref, - approximant=lalsimulation.IMRPhenomD, - LALparams=lal.CreateDict(), + gwsignal_params = { + "mass1": ml4gw_parameters["mass_1"][i].item() * u.solMass, + "mass2": ml4gw_parameters["mass_2"][i].item() * u.solMass, + "deltaT": 1 / sample_rate * u.s, + "f22_start": f_min * u.Hz, + "f22_ref": f_ref * u.Hz, + "phi_ref": ml4gw_parameters["phic"][i].item() * u.rad, + "distance": (ml4gw_parameters["distance"][i].item() * u.Mpc), + "inclination": ml4gw_parameters["inclination"][i].item() * u.rad, + "eccentricity": 0.0 * u.dimensionless_unscaled, + "longAscNodes": 0.0 * u.rad, + "meanPerAno": 0.0 * u.rad, + "condition": 1, + "spin1x": ml4gw_parameters["s1x"][i].item() + * u.dimensionless_unscaled, + "spin1y": ml4gw_parameters["s1y"][i].item() + * u.dimensionless_unscaled, + "spin1z": ml4gw_parameters["s1z"][i].item() + * u.dimensionless_unscaled, + "spin2x": ml4gw_parameters["s2x"][i].item() + * u.dimensionless_unscaled, + "spin2y": ml4gw_parameters["s2y"][i].item() + * u.dimensionless_unscaled, + "spin2z": ml4gw_parameters["s2z"][i].item() + * u.dimensionless_unscaled, + "condition": 1, + } + + gwsignal_generator = GWsignalGenerator("IMRPhenomD") + hp_gwsignal, hc_gwsignal = GenerateTDWaveform( + gwsignal_params, gwsignal_generator ) - hp_lal, hc_lal = lalsimulation.SimInspiralTD(**params) + hp_ml4gw_highpassed = high_pass_time_series( + hp_ml4gw[i].detach().numpy(), 1 / sample_rate, f_min, 0.99, 8.0 + ) + hc_ml4gw_highpassed = high_pass_time_series( + hc_ml4gw[i].detach().numpy(), 1 / sample_rate, f_min, 0.99, 8.0 + ) - # compare the two waveforms - # first center the lal waveforms - lal_times = ( - torch.arange(hp_lal.data.length) * hp_lal.deltaT + hp_lal.epoch + # now align the gwsignal and ml4gw waveforms so we can compoare + ml4gw_times = np.arange( + 0, len(hp_ml4gw_highpassed) / sample_rate, 1 / sample_rate ) - mask = lal_times >= -duration & lal_times <= right_pad + ml4gw_times -= duration - right_pad - hp_lal = hp_lal.data.data[mask] - hc_lal = hc_lal.data.data[mask] + hp_ml4gw_times = ( + ml4gw_times - ml4gw_times[np.argmax(hp_ml4gw_highpassed)] + ) - size = int(duration + right_pad) * sample_rate - hc_ml4gw = hc[i].detach().numpy()[-size:] - hp_ml4gw = hp[i].detach().numpy()[-size:] + max_time = min(hp_ml4gw_times[-1], hp_gwsignal.times.value[-1]) + min_time = max(hp_ml4gw_times[0], hp_gwsignal.times.value[0]) - assert torch.allclose(hc_ml4gw, hc_lal, atol=1e-22) - assert torch.allclose(hp_ml4gw, hp_lal, atol=1e-22) + mask = hp_gwsignal.times.value < max_time + mask &= hp_gwsignal.times.value > min_time + + ml4gw_mask = hp_ml4gw_times < max_time + ml4gw_mask &= hp_ml4gw_times > min_time + + assert np.allclose( + hp_gwsignal.value[mask], + hp_ml4gw_highpassed[ml4gw_mask], + atol=3e-23, + rtol=0.01, + ) + assert np.allclose( + hc_gwsignal.value[mask], + hc_ml4gw_highpassed[ml4gw_mask], + atol=3e-23, + rtol=0.01, + )