Skip to content

Commit

Permalink
add tests for waveform generator
Browse files Browse the repository at this point in the history
  • Loading branch information
EthanMarx committed Jan 31, 2025
1 parent fb8b47d commit ff05830
Showing 1 changed file with 122 additions and 40 deletions.
162 changes: 122 additions & 40 deletions tests/waveforms/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import astropy.units as u
import lal
import lalsimulation
import numpy as np
import pytest
import torch
from lalsimulation.gwsignal.core.waveform import (
GenerateTDWaveform,
LALCompactBinaryCoalescenceGenerator,
)
from scipy.signal import butter, sosfiltfilt

from ml4gw.waveforms import IMRPhenomD, conversion
from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator
Expand All @@ -23,6 +27,53 @@ def sample_rate(request):
return request.param


def high_pass_time_series(time_series, dt, fmin, attenuation, N):
"""
Same as
https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/python/lalsimulation/gwsignal/core/conditioning_subroutines.py?ref_type=heads#L10 # noqa
except w/o requiring gwpy.TimeSeries objects to be passed
"""
fs = 1.0 / dt # Sampling frequency
a1 = attenuation # Attenuation at the low-freq cut-off

w1 = np.tan(np.pi * fmin * dt) # Transformed frequency variable at f_min
wc = w1 * (1.0 / a1**0.5 - 1) ** (
1.0 / (2.0 * N)
) # Cut-off freq. from attenuation
fc = fs * np.arctan(wc) / np.pi # For use in butterworth filter

# Construct the filter and then forward - backward filter the time-series
sos = butter(N, fc, btype="highpass", output="sos", fs=fs)
output = sosfiltfilt(sos, time_series)
return output


class GWsignalGenerator(LALCompactBinaryCoalescenceGenerator):
"""
Override gwsignal class to enforce use of `gwsignal` conditioning routines,
and to enforce that the frequency domain version of the approximant is used
for generating the waveform.
"""

@property
def metadata(self):
metadata = {
"type": "cbc_lalsimulation",
"f_ref_spin": True,
"modes": True,
"polarizations": True,
"implemented_domain": self._implemented_domain,
"generation_domain": self._generation_domain,
"approximant": self._approx_name,
"implementation": "LALSimulation",
"conditioning_routines": "gwsignal",
}
return metadata

def _update_domains(self):
self._implemented_domain = "freq"


def test_cbc_waveform_generator(
chirp_mass,
mass_ratio,
Expand All @@ -33,11 +84,11 @@ def test_cbc_waveform_generator(
theta_jn,
sample_rate,
):
sample_rate = 4096
duration = 1
sample_rate = 2048
duration = 10
f_min = 20
f_ref = 40
right_pad = 0.1
right_pad = 0.5

generator = TimeDomainCBCWaveformGenerator(
approximant=IMRPhenomD(),
Expand All @@ -57,7 +108,7 @@ def test_cbc_waveform_generator(
s2x = torch.zeros_like(chi2)
s2y = torch.zeros_like(chi2)
s2z = chi2
parameters = {
ml4gw_parameters = {
"chirp_mass": chirp_mass,
"mass_ratio": mass_ratio,
"mass_1": mass_1,
Expand All @@ -74,49 +125,80 @@ def test_cbc_waveform_generator(
"distance": distance,
"inclination": theta_jn,
}
hc, hp = generator(**parameters)
hc_ml4gw, hp_ml4gw = generator(**ml4gw_parameters)

# now compare each waveform with lalsimulation SimInspiralTD
for i in range(len(chirp_mass)):

# construct lalinference params
params = dict(
m1=mass_1[i].item() * lal.MSUN_SI,
m2=mass_2[i].item() * lal.MSUN_SI,
S1x=s1x[i].item(),
S1y=s2y[i].item(),
S1z=s1z[i].item(),
S2x=s2x[i].item(),
S2y=s2y[i].item(),
S2z=s1z[i].item(),
distance=(distance[i].item() * u.Mpc).to("m").value,
inclination=theta_jn[i].item(),
phiRef=phase[i].item(),
longAscNodes=0.0,
eccentricity=0.0,
meanPerAno=0.0,
deltaT=1 / sample_rate,
f_min=f_min,
f_ref=f_ref,
approximant=lalsimulation.IMRPhenomD,
LALparams=lal.CreateDict(),
gwsignal_params = {
"mass1": ml4gw_parameters["mass_1"][i].item() * u.solMass,
"mass2": ml4gw_parameters["mass_2"][i].item() * u.solMass,
"deltaT": 1 / sample_rate * u.s,
"f22_start": f_min * u.Hz,
"f22_ref": f_ref * u.Hz,
"phi_ref": ml4gw_parameters["phic"][i].item() * u.rad,
"distance": (ml4gw_parameters["distance"][i].item() * u.Mpc),
"inclination": ml4gw_parameters["inclination"][i].item() * u.rad,
"eccentricity": 0.0 * u.dimensionless_unscaled,
"longAscNodes": 0.0 * u.rad,
"meanPerAno": 0.0 * u.rad,
"condition": 1,
"spin1x": ml4gw_parameters["s1x"][i].item()
* u.dimensionless_unscaled,
"spin1y": ml4gw_parameters["s1y"][i].item()
* u.dimensionless_unscaled,
"spin1z": ml4gw_parameters["s1z"][i].item()
* u.dimensionless_unscaled,
"spin2x": ml4gw_parameters["s2x"][i].item()
* u.dimensionless_unscaled,
"spin2y": ml4gw_parameters["s2y"][i].item()
* u.dimensionless_unscaled,
"spin2z": ml4gw_parameters["s2z"][i].item()
* u.dimensionless_unscaled,
"condition": 1,
}

gwsignal_generator = GWsignalGenerator("IMRPhenomD")
hp_gwsignal, hc_gwsignal = GenerateTDWaveform(
gwsignal_params, gwsignal_generator
)

hp_lal, hc_lal = lalsimulation.SimInspiralTD(**params)
hp_ml4gw_highpassed = high_pass_time_series(
hp_ml4gw[i].detach().numpy(), 1 / sample_rate, f_min, 0.99, 8.0
)
hc_ml4gw_highpassed = high_pass_time_series(
hc_ml4gw[i].detach().numpy(), 1 / sample_rate, f_min, 0.99, 8.0
)

# compare the two waveforms
# first center the lal waveforms
lal_times = (
torch.arange(hp_lal.data.length) * hp_lal.deltaT + hp_lal.epoch
# now align the gwsignal and ml4gw waveforms so we can compoare
ml4gw_times = np.arange(
0, len(hp_ml4gw_highpassed) / sample_rate, 1 / sample_rate
)
mask = lal_times >= -duration & lal_times <= right_pad
ml4gw_times -= duration - right_pad

hp_lal = hp_lal.data.data[mask]
hc_lal = hc_lal.data.data[mask]
hp_ml4gw_times = (
ml4gw_times - ml4gw_times[np.argmax(hp_ml4gw_highpassed)]
)

size = int(duration + right_pad) * sample_rate
hc_ml4gw = hc[i].detach().numpy()[-size:]
hp_ml4gw = hp[i].detach().numpy()[-size:]
max_time = min(hp_ml4gw_times[-1], hp_gwsignal.times.value[-1])
min_time = max(hp_ml4gw_times[0], hp_gwsignal.times.value[0])

assert torch.allclose(hc_ml4gw, hc_lal, atol=1e-22)
assert torch.allclose(hp_ml4gw, hp_lal, atol=1e-22)
mask = hp_gwsignal.times.value < max_time
mask &= hp_gwsignal.times.value > min_time

ml4gw_mask = hp_ml4gw_times < max_time
ml4gw_mask &= hp_ml4gw_times > min_time

assert np.allclose(
hp_gwsignal.value[mask],
hp_ml4gw_highpassed[ml4gw_mask],
atol=3e-23,
rtol=0.01,
)
assert np.allclose(
hc_gwsignal.value[mask],
hc_ml4gw_highpassed[ml4gw_mask],
atol=3e-23,
rtol=0.01,
)

0 comments on commit ff05830

Please sign in to comment.