Skip to content

Commit

Permalink
fix issues found by unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Collin Capano committed Dec 18, 2024
1 parent 2d0ad17 commit 279eedf
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 104 deletions.
69 changes: 35 additions & 34 deletions pycbc/inference/models/gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,40 @@
fd_data_from_strain_dict, gate_overwhitened_data)


def catch_waveform_error(method):
"""Decorator that will catch no waveform errors.
This can be added to a method in an inference model. The decorator will
call the model's `_nowaveform_return` method if either of the following
happens when the wrapped method is executed:
* A `NoWaveformError` is raised.
* A `RuntimeError` or `FailedWaveformError` is raised and the model has
an `ignore_failed_waveforms` attribute that is set to True.
This requires the model to have a `_nowaveform_handler` method.
"""
# the functools.wroaps decorator preserves the original method's name
# and docstring
@wraps(method)
def method_wrapper(self, *args, **kwargs):
try:
retval = method(self, *args, **kwargs)
except NoWaveformError:
retval = self._nowaveform_handler()
except (RuntimeError, FailedWaveformError) as e:
try:
ignore_failed = self.ignore_failed_waveforms
except AttributeError:
ignore_failed = False
if ignore_failed:
retval = self._nowaveform_handler()
else:
raise e
return retval
return method_wrapper


class BaseGaussianNoise(BaseDataModel, metaclass=ABCMeta):
r"""Model for analyzing GW data with assuming a wide-sense stationary
Gaussian noise model.
Expand Down Expand Up @@ -493,7 +527,7 @@ def _nowaveform_handler(self):
"""
raise NotImplementedError(
f"A waveform could not be generated, but this model does not know "
f"how to handle that. The parameters were: {self.current_params}".)
f"how to handle that. The parameters were: {self.current_params}.")

@classmethod
def from_config(cls, cp, data_section='data', data=None, psds=None,
Expand Down Expand Up @@ -1232,36 +1266,3 @@ def create_waveform_generator(
recalib=recalibration, gates=gates,
**static_params)
return waveform_generator


def catch_waveform_error(method):
"""Decorator that will catch no waveform errors.
This can be added to a method in an inference model. The decorator will
call the model's `_nowaveform_return` method if either of the following
happens when the wrapped method is executed:
* A `NoWaveformError` is raised.
* A `RuntimeError` or `FailedWaveformError` is raised and the model has
an `ignore_failed_waveforms` attribute that is set to True.
This requires the model to have a `_nowaveform_return` method.
"""
# the functools.wroaps decorator preserves the original method's name
# and docstring
@wraps(method)
def method_wrapper(self, *args, **kwargs):
try:
retval = method(self, *args, **kwargs)
except NoWaveformError:
retval = self._nowaveform_return()
except (RuntimeError, FailedWaveformError) as e:
try:
ignore_failed = self.ignore_failed_waveforms
except AttributeError:
ignore_failed = False
if ignore_failed:
retval = self._nowaveform_return()
raise e
return retval
return method_wrapper
147 changes: 77 additions & 70 deletions test/test_infmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,84 +178,93 @@ def setUpClass(cls):
cls.psds = {}
cls.data = {}
# static params for the test
tc = 1187008882.42840
flow = 20
cls.static = {
'f_final_func': 'SchwarzISCO',
'approximant':"TaylorF2",
'approximant':"IMRPhenomD",
'mass1': 40.,
'mass2': 40.,
'polarization': 0,
'ra': 3.44615914,
'dec': -0.40808407,
'tc': 1187008882.42840,
'distance': 42.,
'tc': tc,
'distance': 100.,
'inclination': 2.5
}
cls.variable = ['mass1', 'mass2', 'f_lower']
cls.variable = ['spin1z', 'f_lower']
ifos = ['H1', 'L1', 'V1']
# generate the reference psd
flow = 20
seglen = 256
seglen = 8
delta_f = 1./seglen
sample_rate = 4096
delta_t = 1./sample_rate
flen = int(sample_rate * seglen / 2) + 1
psd = aLIGOZeroDetHighPower(flen, delta_f, flow)
# put non-zero values in the beginning and end of the psd
# so the gating models will work
psd[0:int(flow/delta_f+1)] = psd[int(flow/delta_f+1)]
psd[-2:] = psd[-2]
seed = 1000
cls.flow = {'H1': flow, 'L1': flow, 'V1': flow}
# generate the noise
for ifo in ifos:
print("Generating {} noise".format(ifo))

tsamples = int(seglen * sample_rate)
ts = noise_from_psd(tsamples, delta_t, psd, seed=seed)
ts._epoch = cls.static['tc'] - seglen/2
seed += 1027
cls.data[ifo] = ts
cls.data[ifo] = ts.to_frequencyseries()
cls.psds[ifo] = psd
# setup priors
inclination_prior = SinAngle(inclination=None)
distance_prior = Uniform(distance=(10, 100))
tc_prior = Uniform(tc=(m.time-0.1, m.time+0.1))
spin_prior = Uniform(spin1z=(-1., 2.))
flowbad = 4000.
flower_prior = Uniform(f_lower=(flow, flowbad+100.))
pol = UniformAngle(polarization=None)
cls.prior = JointDistribution(cls.variable, inclination_prior,
distance_prior)
cls.prior = JointDistribution(cls.variable, spin_prior, flower_prior)

# set up for marginalized polarization tests
cls.static2 = cls.static.copy()
cls.static2.pop('polarization')
cls.variable2 = cls.variable + ['polarization']
cls.prior2 = JointDistribution(cls.variable2, inclination_prior,
distance_prior, pol)
cls.prior2 = JointDistribution(cls.variable2, spin_prior, flower_prior,
pol)
# set up gated parameters
staticgate = cls.static.copy()
staticgate['t_gate_start'] = static['tc'] - 1.
staticgate['t_gate_end'] = static['tc']
self.staticgate = staticgate
staticgate['t_gate_start'] = tc - 0.05
staticgate['t_gate_end'] = tc
cls.staticgate = staticgate
# margpol
staticgate2 = cls.static2.copy()
staticgate2['t_gate_start'] = static2['tc'] - 1.
staticgate2['t_gate_end'] = static2['tc']
self.staticgate2 = staticgate2
staticgate2['t_gate_start'] = tc - 0.05
staticgate2['t_gate_end'] = tc
cls.staticgate2 = staticgate2
# the parameters to test:
# these parameters should pass
cls.pass_params = {'mass1': 1.4, 'mass2': 1.4, 'f_lower': flow}
cls.pass_params = {'spin1z': 0., 'f_lower': flow}
# these parameters should trigger a NoWaveformError
cls.nowf_params = {'mass1': 1.4, 'mass2': 1.4, 'f_lower': 4000.}
cls.nowf_params = {'spin1z': 0., 'f_lower': flowbad}
# these parameters should cause a FailedWaveformError
cls.fail_params = {'mass1': -1.4, 'mass2': 1.4, 'f_lower': flow}

def _run_tests(self, model):
# first check that the model works
model.update(**self.pass_params)
self.assertTrue(numpy.isfinite(model.loglr))
# now check that a no waveform error is caught correctly
model.update(**self.nowf_params)
self.assertEqual(model.loglr == -numpy.inf)
# now check that a failed waveform is caught correctly
model.update(**self.fail_params)
self.assertEqual(model.loglr == -numpy.inf)
# now check that an error is raised if ignore_failed_waveforms is False
model.ignore_failed_waveforms = False
model.update(**self.fail_params)
model.assertRaises((FailedWaveformError, RuntimeError), model.loglr)
cls.fail_params = {'spin1z': 2., 'f_lower': flow}

def _run_tests(self, model, check_pass=True, check_nowf=True,
check_failed=True, check_raises=True):
# check that the model works
if check_pass:
model.update(**self.pass_params)
self.assertTrue(numpy.isfinite(model.loglr))
# check that a no waveform error is caught correctly
if check_nowf:
model.update(**self.nowf_params)
self.assertEqual(model.loglr, -numpy.inf)
# check that a failed waveform is caught correctly
if check_failed:
model.update(**self.fail_params)
self.assertEqual(model.loglr, -numpy.inf)
# check that an error is raised if ignore_failed_waveforms is False
if check_raises:
model.ignore_failed_waveforms = False
model.update(**self.fail_params)
with self.assertRaises((FailedWaveformError, RuntimeError)):
model.loglr

def test_base_phase_marg(self):
model = models.MarginalizedPhaseGaussianNoise(
Expand All @@ -273,64 +282,62 @@ def test_relative_phase_marg(self):
psds = self.psds,
static_params = self.static,
prior = self.prior,
fiducial_params = {'mass1':1.3756},
fiducial_params = {},
#fiducial_params = {'mass1': 40.},
epsilon = .1,
)
self._run_tests(model)
ignore_failed_waveforms=True)
# relative model doesn't respect flower, so no point in testing nowf
self._run_tests(model, check_nowf=False)

def test_single_phase_marg(self):
model = models.SingleTemplate(
self.variable, copy.deepcopy(self.data),
low_frequency_cutoff=self.flow,
psds = self.psds,
static_params = self.static,
prior = self.prior,
)
self._run_tests(model)

def test_single_pol_phase_marg(self):
model = models.SingleTemplate(
self.variable2, copy.deepcopy(self.data),
def test_brute_pol_phase_marg(self):
# Uses the old polarization syntax untill we decide to remove it.
# Untill then, this also tests that that interface stays working.
model = models.BruteParallelGaussianMarginalize(
self.variable, data=copy.deepcopy(self.data),
low_frequency_cutoff=self.flow,
psds = self.psds,
static_params = self.static2,
prior = self.prior2,
marginalize_vector_samples = 1000,
marginalize_vector_params = 'polarization',
prior = self.prior,
marginalize_phase=4,
cores=1,
base_model='marginalized_polarization',
ignore_failed_waveforms=True
)
self._run_tests(model)

def test_brute_pol_phase_marg(self):
# Uses the old polarization syntax untill we decide to remove it.
# Untill then, this also tests that that interface stays working.
# we need to do the check raises test separately because the underlying
# base model's ignore_failed_waveforms needs to be set
self._run_tests(model, check_raises=False)
model = models.BruteParallelGaussianMarginalize(
self.variable, data=copy.deepcopy(self.data),
low_frequency_cutoff=self.flow,
psds = self.psds,
static_params = self.static2,
prior = self.prior,
marginalize_phase=400,
marginalize_phase=4,
cores=1,
base_model='marginalized_polarization',
ignore_failed_waveforms=False
)
self._run_tests(model)
self._run_tests(model, check_pass=False, check_nowf=False,
check_failed=False, check_raises=True)

def test_gated_gaussian_noise(self):
model = models.GatedGaussianNoise(
self.variable, data=copy.deepcopy(self.data),
low_frequency_cutoff=self.flow,
psds=self.psds,
static_params=self.staticgate,
prior=self.prior)
prior=self.prior,
ignore_failed_waveforms=True)
self._run_tests(model)

def test_gated_gaussian_margpol(self):
model = models.GatedGaussianNoise(
model = models.GatedGaussianMargPol(
self.variable, data=copy.deepcopy(self.data),
low_frequency_cutoff=self.flow,
psds=self.psds,
static_params=self.staticgate2,
prior=self.prior)
prior=self.prior,
ignore_failed_waveforms=True)
self._run_tests(model)


Expand Down

0 comments on commit 279eedf

Please sign in to comment.