From 242620a58d9f3a10e8b6cb34dcfd47212caa5d85 Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Wed, 18 Dec 2024 12:01:46 -0500 Subject: [PATCH] Add waveform error handling to Relbin (and other models) (#4996) * catch failed waveforms by using a decorator * add ability to catch failed waveforms to relbin models * add RuntimeError to list of errors caught * add unittests * fix issues found by unit tests * raise an error in Relative model if return_sh_hh set to True --------- Co-authored-by: Collin Capano --- .../inference/models/gated_gaussian_noise.py | 37 +--- pycbc/inference/models/gaussian_noise.py | 62 ++++-- .../models/marginalized_gaussian_noise.py | 87 +++------ pycbc/inference/models/relbin.py | 32 +++- test/test_infmodel.py | 178 +++++++++++++++++- 5 files changed, 296 insertions(+), 100 deletions(-) diff --git a/pycbc/inference/models/gated_gaussian_noise.py b/pycbc/inference/models/gated_gaussian_noise.py index 33e7ecc27e1..a2d6128cec9 100644 --- a/pycbc/inference/models/gated_gaussian_noise.py +++ b/pycbc/inference/models/gated_gaussian_noise.py @@ -23,14 +23,14 @@ import numpy from scipy import special -from pycbc.waveform import (NoWaveformError, FailedWaveformError) from pycbc.types import FrequencySeries from pycbc.detector import Detector from pycbc.pnutils import hybrid_meco_frequency from pycbc.waveform.utils import time_from_frequencyseries from pycbc.waveform import generator from pycbc.filter import highpass -from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator) +from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator, + catch_waveform_error) from .base_data import BaseDataModel from .data_utils import fd_data_from_strain_dict @@ -134,8 +134,7 @@ def normalize(self, normalize): """ self._normalize = normalize - @staticmethod - def _nowaveform_logl(): + def _nowaveform_handler(self): """Convenience function to set logl values if no waveform generated. 
""" return -numpy.inf @@ -329,20 +328,14 @@ def get_gate_times(self): def get_gate_times_hmeco(self): """Gets the time to apply a gate based on the current sky position. + Returns ------- dict : Dictionary of detector names -> (gate start, gate width) """ # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e + wfs = self.get_waveforms() # get waveform parameters params = self.current_params spin1 = params['spin1z'] @@ -514,6 +507,7 @@ def _extra_stats(self): """No extra stats are stored.""" return [] + @catch_waveform_error def _loglikelihood(self): r"""Computes the log likelihood after removing the power within the given time window, @@ -530,14 +524,7 @@ def _loglikelihood(self): The value of the log likelihood. """ # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e + wfs = self.get_waveforms() # get the times of the gates gate_times = self.get_gate_times() logl = 0. @@ -681,6 +668,7 @@ def _extra_stats(self): """Adds the maxL polarization and corresponding likelihood.""" return ['maxl_polarization', 'maxl_logl'] + @catch_waveform_error def _loglikelihood(self): r"""Computes the log likelihood after removing the power within the given time window, @@ -697,14 +685,7 @@ def _loglikelihood(self): The value of the log likelihood. 
""" # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e + wfs = self.get_waveforms() # get the gated waveforms and data gated_wfs = self.get_gated_waveforms() gated_data = self.get_gated_data() diff --git a/pycbc/inference/models/gaussian_noise.py b/pycbc/inference/models/gaussian_noise.py index 5844f58b956..1442d7ae7e9 100644 --- a/pycbc/inference/models/gaussian_noise.py +++ b/pycbc/inference/models/gaussian_noise.py @@ -19,6 +19,7 @@ import logging import shlex from abc import ABCMeta +from functools import wraps import numpy from pycbc import filter as pyfilter @@ -37,6 +38,40 @@ fd_data_from_strain_dict, gate_overwhitened_data) +def catch_waveform_error(method): + """Decorator that will catch no waveform errors. + + This can be added to a method in an inference model. The decorator will + call the model's `_nowaveform_handler` method if either of the following + happens when the wrapped method is executed: + + * A `NoWaveformError` is raised. + * A `RuntimeError` or `FailedWaveformError` is raised and the model has + an `ignore_failed_waveforms` attribute that is set to True. + + This requires the model to have a `_nowaveform_handler` method. 
+ """ + # the functools.wraps decorator preserves the original method's name + # and docstring + @wraps(method) + def method_wrapper(self, *args, **kwargs): + try: + retval = method(self, *args, **kwargs) + except NoWaveformError: + retval = self._nowaveform_handler() + except (RuntimeError, FailedWaveformError) as e: + try: + ignore_failed = self.ignore_failed_waveforms + except AttributeError: + ignore_failed = False + if ignore_failed: + retval = self._nowaveform_handler() + else: + raise e + return retval + return method_wrapper + + class BaseGaussianNoise(BaseDataModel, metaclass=ABCMeta): r"""Model for analyzing GW data with assuming a wide-sense stationary Gaussian noise model. @@ -482,6 +517,18 @@ def _fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict): """Wrapper around :py:func:`data_utils.fd_data_from_strain_dict`.""" return fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict) + def _nowaveform_handler(self): + """Method that gets called if a NoWaveformError or FailedWaveformError + is raised. See the :py:func:`catch_waveform_error` decorator for details. + + Here, this will just raise a NotImplementedError, since how this should + be handled is model dependent. Models that wish to deal with this + scenario should override this method. + """ + raise NotImplementedError( + f"A waveform could not be generated, but this model does not know " + f"how to handle that. The parameters were: {self.current_params}.") + @classmethod def from_config(cls, cp, data_section='data', data=None, psds=None, **kwargs): @@ -872,7 +919,7 @@ def _extra_stats(self): ['{}_cplx_loglr'.format(det) for det in self._data] + \ ['{}_optimal_snrsq'.format(det) for det in self._data] - def _nowaveform_loglr(self): + def _nowaveform_handler(self): """Convenience function to set loglr values if no waveform generated. 
""" for det in self._data: @@ -890,6 +937,7 @@ def multi_signal_support(self): """ return [type(self)] + @catch_waveform_error def multi_loglikelihood(self, models): """ Calculate a multi-model (signal) likelihood """ @@ -931,6 +979,7 @@ def get_waveforms(self): self._current_wfs = wfs return self._current_wfs + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, @@ -947,16 +996,7 @@ def _loglr(self): float The value of the log likelihood ratio. """ - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_loglr() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_loglr() - else: - raise e - + wfs = self.get_waveforms() lr = 0. for det, h in wfs.items(): # the kmax of the waveforms may be different than internal kmax diff --git a/pycbc/inference/models/marginalized_gaussian_noise.py b/pycbc/inference/models/marginalized_gaussian_noise.py index 05a402aa8cf..7496a9cbc80 100644 --- a/pycbc/inference/models/marginalized_gaussian_noise.py +++ b/pycbc/inference/models/marginalized_gaussian_noise.py @@ -24,11 +24,10 @@ from scipy import special from pycbc.waveform import generator -from pycbc.waveform import (NoWaveformError, FailedWaveformError) from pycbc.detector import Detector from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator, - GaussianNoise) + GaussianNoise, catch_waveform_error) from .tools import marginalize_likelihood, DistMarg @@ -129,7 +128,7 @@ def _extra_stats(self): return ['loglr', 'maxl_phase'] + \ ['{}_optimal_snrsq'.format(det) for det in self._data] - def _nowaveform_loglr(self): + def _nowaveform_handler(self): """Convenience function to set loglr values if no waveform generated. """ setattr(self._current_stats, 'loglikelihood', -numpy.inf) @@ -140,6 +139,7 @@ def _nowaveform_loglr(self): setattr(self._current_stats, '{}_optimal_snrsq'.format(det), 0.) 
return -numpy.inf + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, .. math:: @@ -153,21 +153,12 @@ def _loglr(self): The value of the log likelihood ratio evaluated at the given point. """ params = self.current_params - try: - if self.all_ifodata_same_rate_length: - wfs = self.waveform_generator.generate(**params) - else: - wfs = {} - for det in self.data: - wfs.update(self.waveform_generator[det].generate(**params)) - - except NoWaveformError: - return self._nowaveform_loglr() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_loglr() - else: - raise e + if self.all_ifodata_same_rate_length: + wfs = self.waveform_generator.generate(**params) + else: + wfs = {} + for det in self.data: + wfs.update(self.waveform_generator[det].generate(**params)) hh = 0. hd = 0j for det, h in wfs.items(): @@ -254,11 +245,12 @@ def __init__(self, variable_params, logging.info("Using %s sample rate for marginalization", sample_rate) - def _nowaveform_loglr(self): + def _nowaveform_handler(self): """Convenience function to set loglr values if no waveform generated. """ return -numpy.inf + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, or inner product and if `self.return_sh_hh` is True. 
@@ -279,21 +271,12 @@ def _loglr(self): from pycbc.filter import matched_filter_core params = self.current_params - try: - if self.all_ifodata_same_rate_length: - wfs = self.waveform_generator.generate(**params) - else: - wfs = {} - for det in self.data: - wfs.update(self.waveform_generator[det].generate(**params)) - except NoWaveformError: - return self._nowaveform_loglr() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_loglr() - else: - raise e - + if self.all_ifodata_same_rate_length: + wfs = self.waveform_generator.generate(**params) + else: + wfs = {} + for det in self.data: + wfs.update(self.waveform_generator[det].generate(**params)) sh_total = hh_total = 0. snr_estimate = {} cplx_hpd = {} @@ -436,7 +419,7 @@ def _extra_stats(self): return ['loglr', 'maxl_polarization', 'maxl_loglr'] + \ ['{}_optimal_snrsq'.format(det) for det in self._data] - def _nowaveform_loglr(self): + def _nowaveform_handler(self): """Convenience function to set loglr values if no waveform generated. """ setattr(self._current_stats, 'loglr', -numpy.inf) @@ -447,6 +430,7 @@ def _nowaveform_loglr(self): setattr(self._current_stats, '{}_optimal_snrsq'.format(det), 0.) return -numpy.inf + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, @@ -464,20 +448,12 @@ def _loglr(self): The value of the log likelihood ratio. 
""" params = self.current_params - try: - if self.all_ifodata_same_rate_length: - wfs = self.waveform_generator.generate(**params) - else: - wfs = {} - for det in self.data: - wfs.update(self.waveform_generator[det].generate(**params)) - except NoWaveformError: - return self._nowaveform_loglr() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_loglr() - else: - raise e + if self.all_ifodata_same_rate_length: + wfs = self.waveform_generator.generate(**params) + else: + wfs = {} + for det in self.data: + wfs.update(self.waveform_generator[det].generate(**params)) lr = sh_total = hh_total = 0. for det, (hp, hc) in wfs.items(): @@ -643,7 +619,7 @@ def _extra_stats(self): """ return ['maxl_polarization', 'maxl_phase', ] - def _nowaveform_loglr(self): + def _nowaveform_handler(self): """Convenience function to set loglr values if no waveform generated. """ # maxl phase doesn't exist, so set it to nan @@ -651,6 +627,7 @@ def _nowaveform_loglr(self): setattr(self._current_stats, 'maxl_phase', numpy.nan) return -numpy.inf + @catch_waveform_error def _loglr(self, return_unmarginalized=False): r"""Computes the log likelihood ratio, @@ -668,15 +645,7 @@ def _loglr(self, return_unmarginalized=False): The value of the log likelihood ratio. 
""" params = self.current_params - try: - wfs = self.waveform_generator.generate(**params) - except NoWaveformError: - return self._nowaveform_loglr() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_loglr() - else: - raise e + wfs = self.waveform_generator.generate(**params) # --------------------------------------------------------------------- # Some optimizations not yet taken: diff --git a/pycbc/inference/models/relbin.py b/pycbc/inference/models/relbin.py index 8c6be79a1ec..018260794f3 100644 --- a/pycbc/inference/models/relbin.py +++ b/pycbc/inference/models/relbin.py @@ -36,7 +36,8 @@ from pycbc.detector import Detector from pycbc.types import Array, TimeSeries -from .gaussian_noise import BaseGaussianNoise +from .gaussian_noise import (BaseGaussianNoise, catch_waveform_error) +from pycbc.waveform import FailedWaveformError from .relbin_cpu import (likelihood_parts, likelihood_parts_v, likelihood_parts_multi, likelihood_parts_multi_v, likelihood_parts_det, likelihood_parts_det_multi, @@ -523,6 +524,7 @@ def multi_loglikelihood(self, models): loglr += - h1h2.real # This is -0.5 * re( + ) return loglr + self.lognl + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, or inner product and if `self.return_sh_hh` is True. @@ -603,6 +605,17 @@ def _loglr(self): results = loglr return results + def _nowaveform_handler(self): + """Returns -inf for loglr if no waveform generated. + + If `return_sh_hh` is set to True, a FailedWaveformError will be raised. + """ + if self.return_sh_hh: + raise FailedWaveformError("Waveform failed to generate and " + "return_sh_hh set to True! I don't know " + "what to return in this case.") + return -numpy.inf + def write_metadata(self, fp, group=None): """Adds writing the fiducial parameters and epsilon to file's attrs. 
@@ -718,6 +731,7 @@ def get_snr(self, wfs): epoch=self.tstart[ifo] - delta_t * 2.0) return snrs + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, @@ -785,6 +799,11 @@ def _loglr(self): loglr = self.marginalize_loglr(filt, norm) return loglr + def _nowaveform_handler(self): + """Sets loglr values if no waveform generated. + """ + return -numpy.inf + class RelativeTimeDom(RelativeTime): """ Heterodyne likelihood optimized for time marginalization and only @@ -820,6 +839,7 @@ def get_snr(self, wfs): return snrs + @catch_waveform_error def _loglr(self): r"""Computes the log likelihood ratio, or inner product and if `self.return_sh_hh` is True. @@ -881,3 +901,13 @@ def _loglr(self): else: results = loglr return results + + def _nowaveform_handler(self): + """Sets loglr values if no waveform generated. + """ + loglr = sh_total = hh_total = -numpy.inf + if self.return_sh_hh: + results = (sh_total, hh_total) + else: + results = loglr + return results diff --git a/test/test_infmodel.py b/test/test_infmodel.py index 9d4555c04c0..3a779f6165c 100644 --- a/test/test_infmodel.py +++ b/test/test_infmodel.py @@ -27,13 +27,16 @@ import unittest import copy from utils import simple_exit +import numpy from pycbc.catalog import Merger -from pycbc.psd import interpolate, inverse_spectrum_truncation +from pycbc.psd import interpolate, inverse_spectrum_truncation, aLIGOZeroDetHighPower +from pycbc.noise import noise_from_psd from pycbc.frame import read_frame from pycbc.filter import highpass, resample_to_delta_t from astropy.utils.data import download_file from pycbc.inference import models from pycbc.distributions import Uniform, JointDistribution, SinAngle, UniformAngle +from pycbc.waveform.waveform import FailedWaveformError class TestModels(unittest.TestCase): @@ -166,8 +169,181 @@ def test_brute_pol_phase_marg(self): model.update(**self.q1) self.assertAlmostEqual(self.a2, model.loglr, delta=0.002) + +class TestWaveformErrors(unittest.TestCase): + 
"""Tests that models handle no waveform errors correctly.""" + + @classmethod + def setUpClass(cls): + cls.psds = {} + cls.data = {} + # static params for the test + tc = 1187008882.42840 + flow = 20 + cls.static = { + 'approximant':"IMRPhenomD", + 'mass1': 40., + 'mass2': 40., + 'polarization': 0, + 'ra': 3.44615914, + 'dec': -0.40808407, + 'tc': tc, + 'distance': 100., + 'inclination': 2.5 + } + cls.variable = ['spin1z', 'f_lower'] + ifos = ['H1', 'L1', 'V1'] + # generate the reference psd + seglen = 8 + delta_f = 1./seglen + sample_rate = 4096 + delta_t = 1./sample_rate + flen = int(sample_rate * seglen / 2) + 1 + psd = aLIGOZeroDetHighPower(flen, delta_f, flow) + # put non-zero values in the beginning and end of the psd + # so the gating models will work + psd[0:int(flow/delta_f+1)] = psd[int(flow/delta_f+1)] + psd[-2:] = psd[-2] + seed = 1000 + cls.flow = {'H1': flow, 'L1': flow, 'V1': flow} + # generate the noise + for ifo in ifos: + tsamples = int(seglen * sample_rate) + ts = noise_from_psd(tsamples, delta_t, psd, seed=seed) + ts._epoch = cls.static['tc'] - seglen/2 + seed += 1027 + cls.data[ifo] = ts.to_frequencyseries() + cls.psds[ifo] = psd + # setup priors + spin_prior = Uniform(spin1z=(-1., 2.)) + flowbad = 4000. 
+ flower_prior = Uniform(f_lower=(flow, flowbad+100.)) + pol = UniformAngle(polarization=None) + cls.prior = JointDistribution(cls.variable, spin_prior, flower_prior) + + # set up for marginalized polarization tests + cls.static2 = cls.static.copy() + cls.static2.pop('polarization') + cls.variable2 = cls.variable + ['polarization'] + cls.prior2 = JointDistribution(cls.variable2, spin_prior, flower_prior, + pol) + # set up gated parameters + staticgate = cls.static.copy() + staticgate['t_gate_start'] = tc - 0.05 + staticgate['t_gate_end'] = tc + cls.staticgate = staticgate + # margpol + staticgate2 = cls.static2.copy() + staticgate2['t_gate_start'] = tc - 0.05 + staticgate2['t_gate_end'] = tc + cls.staticgate2 = staticgate2 + # the parameters to test: + # these parameters should pass + cls.pass_params = {'spin1z': 0., 'f_lower': flow} + # these parameters should trigger a NoWaveformError + cls.nowf_params = {'spin1z': 0., 'f_lower': flowbad} + # these parameters should cause a FailedWaveformError + cls.fail_params = {'spin1z': 2., 'f_lower': flow} + + def _run_tests(self, model, check_pass=True, check_nowf=True, + check_failed=True, check_raises=True): + # check that the model works + if check_pass: + model.update(**self.pass_params) + self.assertTrue(numpy.isfinite(model.loglr)) + # check that a no waveform error is caught correctly + if check_nowf: + model.update(**self.nowf_params) + self.assertEqual(model.loglr, -numpy.inf) + # check that a failed waveform is caught correctly + if check_failed: + model.update(**self.fail_params) + self.assertEqual(model.loglr, -numpy.inf) + # check that an error is raised if ignore_failed_waveforms is False + if check_raises: + model.ignore_failed_waveforms = False + model.update(**self.fail_params) + with self.assertRaises((FailedWaveformError, RuntimeError)): + model.loglr + + def test_base_phase_marg(self): + model = models.MarginalizedPhaseGaussianNoise( + self.variable, copy.deepcopy(self.data), + 
low_frequency_cutoff=self.flow, + psds=self.psds, + static_params=self.static, + prior=self.prior, + ignore_failed_waveforms=True) + self._run_tests(model) + + def test_relative_phase_marg(self): + model = models.Relative(self.variable, copy.deepcopy(self.data), + low_frequency_cutoff=self.flow, + psds = self.psds, + static_params = self.static, + prior = self.prior, + fiducial_params = {}, + #fiducial_params = {'mass1': 40.}, + epsilon = .1, + ignore_failed_waveforms=True) + # relative model doesn't respect flower, so no point in testing nowf + self._run_tests(model, check_nowf=False) + + def test_brute_pol_phase_marg(self): + # Uses the old polarization syntax until we decide to remove it. + # Until then, this also tests that that interface stays working. + model = models.BruteParallelGaussianMarginalize( + self.variable, data=copy.deepcopy(self.data), + low_frequency_cutoff=self.flow, + psds = self.psds, + static_params = self.static2, + prior = self.prior, + marginalize_phase=4, + cores=1, + base_model='marginalized_polarization', + ignore_failed_waveforms=True + ) + # we need to do the check raises test separately because the underlying + # base model's ignore_failed_waveforms needs to be set + self._run_tests(model, check_raises=False) + model = models.BruteParallelGaussianMarginalize( + self.variable, data=copy.deepcopy(self.data), + low_frequency_cutoff=self.flow, + psds = self.psds, + static_params = self.static2, + prior = self.prior, + marginalize_phase=4, + cores=1, + base_model='marginalized_polarization', + ignore_failed_waveforms=False + ) + self._run_tests(model, check_pass=False, check_nowf=False, + check_failed=False, check_raises=True) + + def test_gated_gaussian_noise(self): + model = models.GatedGaussianNoise( + self.variable, data=copy.deepcopy(self.data), + low_frequency_cutoff=self.flow, + psds=self.psds, + static_params=self.staticgate, + prior=self.prior, + ignore_failed_waveforms=True) + self._run_tests(model) + + def 
test_gated_gaussian_margpol(self): + model = models.GatedGaussianMargPol( + self.variable, data=copy.deepcopy(self.data), + low_frequency_cutoff=self.flow, + psds=self.psds, + static_params=self.staticgate2, + prior=self.prior, + ignore_failed_waveforms=True) + self._run_tests(model) + + suite = unittest.TestSuite() suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestModels)) +suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestWaveformErrors)) if __name__ == '__main__': from astropy.utils import iers