Skip to content

Commit

Permalink
Add waveform error handling to Relbin (and other models) (#4996)
Browse files Browse the repository at this point in the history
* catch failed waveforms by using a decorator

* add ability to catch failed waveforms to relbin models

* add RuntimeError to list of errors caught

* add unittests

* fix issues found by unit tests

* raise an error in Relative model if return_sh_hh set to True

---------

Co-authored-by: Collin Capano <[email protected]>
  • Loading branch information
cdcapano and Collin Capano authored Dec 18, 2024
1 parent 91994c8 commit 242620a
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 100 deletions.
37 changes: 9 additions & 28 deletions pycbc/inference/models/gated_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import numpy
from scipy import special

from pycbc.waveform import (NoWaveformError, FailedWaveformError)
from pycbc.types import FrequencySeries
from pycbc.detector import Detector
from pycbc.pnutils import hybrid_meco_frequency
from pycbc.waveform.utils import time_from_frequencyseries
from pycbc.waveform import generator
from pycbc.filter import highpass
from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator)
from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator,
catch_waveform_error)
from .base_data import BaseDataModel
from .data_utils import fd_data_from_strain_dict

Expand Down Expand Up @@ -134,8 +134,7 @@ def normalize(self, normalize):
"""
self._normalize = normalize

@staticmethod
def _nowaveform_logl():
def _nowaveform_handler(self):
"""Convenience function to set logl values if no waveform generated.
"""
return -numpy.inf
Expand Down Expand Up @@ -329,20 +328,14 @@ def get_gate_times(self):

def get_gate_times_hmeco(self):
"""Gets the time to apply a gate based on the current sky position.
Returns
-------
dict :
Dictionary of detector names -> (gate start, gate width)
"""
# generate the template waveform
try:
wfs = self.get_waveforms()
except NoWaveformError:
return self._nowaveform_logl()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_logl()
raise e
wfs = self.get_waveforms()
# get waveform parameters
params = self.current_params
spin1 = params['spin1z']
Expand Down Expand Up @@ -514,6 +507,7 @@ def _extra_stats(self):
"""No extra stats are stored."""
return []

@catch_waveform_error
def _loglikelihood(self):
r"""Computes the log likelihood after removing the power within the
given time window,
Expand All @@ -530,14 +524,7 @@ def _loglikelihood(self):
The value of the log likelihood.
"""
# generate the template waveform
try:
wfs = self.get_waveforms()
except NoWaveformError:
return self._nowaveform_logl()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_logl()
raise e
wfs = self.get_waveforms()
# get the times of the gates
gate_times = self.get_gate_times()
logl = 0.
Expand Down Expand Up @@ -681,6 +668,7 @@ def _extra_stats(self):
"""Adds the maxL polarization and corresponding likelihood."""
return ['maxl_polarization', 'maxl_logl']

@catch_waveform_error
def _loglikelihood(self):
r"""Computes the log likelihood after removing the power within the
given time window,
Expand All @@ -697,14 +685,7 @@ def _loglikelihood(self):
The value of the log likelihood.
"""
# generate the template waveform
try:
wfs = self.get_waveforms()
except NoWaveformError:
return self._nowaveform_logl()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_logl()
raise e
wfs = self.get_waveforms()
# get the gated waveforms and data
gated_wfs = self.get_gated_waveforms()
gated_data = self.get_gated_data()
Expand Down
62 changes: 51 additions & 11 deletions pycbc/inference/models/gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import shlex
from abc import ABCMeta
from functools import wraps
import numpy

from pycbc import filter as pyfilter
Expand All @@ -37,6 +38,40 @@
fd_data_from_strain_dict, gate_overwhitened_data)


def catch_waveform_error(method):
    """Decorator that will catch waveform generation errors.

    This can be added to a method in an inference model. The decorator will
    call the model's `_nowaveform_handler` method, and return its result in
    place of the wrapped method's, if either of the following happens when
    the wrapped method is executed:

    * A `NoWaveformError` is raised.
    * A `RuntimeError` or `FailedWaveformError` is raised and the model has
      an `ignore_failed_waveforms` attribute that is set to True.

    If a `RuntimeError` or `FailedWaveformError` is raised and the model's
    `ignore_failed_waveforms` attribute is missing or evaluates to False, the
    error is re-raised. Any other exception type propagates untouched.

    This requires the model to have a `_nowaveform_handler` method.

    Parameters
    ----------
    method : callable
        The model method to wrap. It is invoked as
        ``method(self, *args, **kwargs)``.

    Returns
    -------
    callable
        The wrapped method, carrying the original method's name and
        docstring.
    """
    # the functools.wraps decorator preserves the original method's name
    # and docstring
    @wraps(method)
    def method_wrapper(self, *args, **kwargs):
        try:
            return method(self, *args, **kwargs)
        except NoWaveformError:
            # no waveform is always handled by the model, regardless of the
            # ignore_failed_waveforms setting
            return self._nowaveform_handler()
        except (RuntimeError, FailedWaveformError):
            # models without the attribute are treated as not ignoring
            # failed waveforms
            if getattr(self, 'ignore_failed_waveforms', False):
                return self._nowaveform_handler()
            # bare raise re-raises the active exception with its original
            # traceback intact
            raise
    return method_wrapper


class BaseGaussianNoise(BaseDataModel, metaclass=ABCMeta):
r"""Model for analyzing GW data with assuming a wide-sense stationary
Gaussian noise model.
Expand Down Expand Up @@ -482,6 +517,18 @@ def _fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict):
"""Wrapper around :py:func:`data_utils.fd_data_from_strain_dict`."""
return fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict)

def _nowaveform_handler(self):
    """Fallback called when a waveform could not be generated.

    The :py:func:`catch_waveform_error` decorator calls this method when a
    NoWaveformError or FailedWaveformError is raised by the method it wraps.
    How that situation should be dealt with is model dependent, so this
    default implementation only raises a NotImplementedError that reports
    the current parameters. Models that wish to deal with this scenario
    should override this method.

    Raises
    ------
    NotImplementedError
        Always.
    """
    params = self.current_params
    message = (
        "A waveform could not be generated, but this model does not know "
        "how to handle that. The parameters were: {}.".format(params))
    raise NotImplementedError(message)

@classmethod
def from_config(cls, cp, data_section='data', data=None, psds=None,
**kwargs):
Expand Down Expand Up @@ -872,7 +919,7 @@ def _extra_stats(self):
['{}_cplx_loglr'.format(det) for det in self._data] + \
['{}_optimal_snrsq'.format(det) for det in self._data]

def _nowaveform_loglr(self):
def _nowaveform_handler(self):
"""Convenience function to set loglr values if no waveform generated.
"""
for det in self._data:
Expand All @@ -890,6 +937,7 @@ def multi_signal_support(self):
"""
return [type(self)]

@catch_waveform_error
def multi_loglikelihood(self, models):
""" Calculate a multi-model (signal) likelihood
"""
Expand Down Expand Up @@ -931,6 +979,7 @@ def get_waveforms(self):
self._current_wfs = wfs
return self._current_wfs

@catch_waveform_error
def _loglr(self):
r"""Computes the log likelihood ratio,
Expand All @@ -947,16 +996,7 @@ def _loglr(self):
float
The value of the log likelihood ratio.
"""
try:
wfs = self.get_waveforms()
except NoWaveformError:
return self._nowaveform_loglr()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_loglr()
else:
raise e

wfs = self.get_waveforms()
lr = 0.
for det, h in wfs.items():
# the kmax of the waveforms may be different than internal kmax
Expand Down
87 changes: 28 additions & 59 deletions pycbc/inference/models/marginalized_gaussian_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
from scipy import special

from pycbc.waveform import generator
from pycbc.waveform import (NoWaveformError, FailedWaveformError)
from pycbc.detector import Detector
from .gaussian_noise import (BaseGaussianNoise,
create_waveform_generator,
GaussianNoise)
GaussianNoise, catch_waveform_error)
from .tools import marginalize_likelihood, DistMarg


Expand Down Expand Up @@ -129,7 +128,7 @@ def _extra_stats(self):
return ['loglr', 'maxl_phase'] + \
['{}_optimal_snrsq'.format(det) for det in self._data]

def _nowaveform_loglr(self):
def _nowaveform_handler(self):
"""Convenience function to set loglr values if no waveform generated.
"""
setattr(self._current_stats, 'loglikelihood', -numpy.inf)
Expand All @@ -140,6 +139,7 @@ def _nowaveform_loglr(self):
setattr(self._current_stats, '{}_optimal_snrsq'.format(det), 0.)
return -numpy.inf

@catch_waveform_error
def _loglr(self):
r"""Computes the log likelihood ratio,
.. math::
Expand All @@ -153,21 +153,12 @@ def _loglr(self):
The value of the log likelihood ratio evaluated at the given point.
"""
params = self.current_params
try:
if self.all_ifodata_same_rate_length:
wfs = self.waveform_generator.generate(**params)
else:
wfs = {}
for det in self.data:
wfs.update(self.waveform_generator[det].generate(**params))

except NoWaveformError:
return self._nowaveform_loglr()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_loglr()
else:
raise e
if self.all_ifodata_same_rate_length:
wfs = self.waveform_generator.generate(**params)
else:
wfs = {}
for det in self.data:
wfs.update(self.waveform_generator[det].generate(**params))
hh = 0.
hd = 0j
for det, h in wfs.items():
Expand Down Expand Up @@ -254,11 +245,12 @@ def __init__(self, variable_params,
logging.info("Using %s sample rate for marginalization",
sample_rate)

def _nowaveform_loglr(self):
def _nowaveform_handler(self):
"""Convenience function to set loglr values if no waveform generated.
"""
return -numpy.inf

@catch_waveform_error
def _loglr(self):
r"""Computes the log likelihood ratio,
or inner product <s|h> and <h|h> if `self.return_sh_hh` is True.
Expand All @@ -279,21 +271,12 @@ def _loglr(self):
from pycbc.filter import matched_filter_core

params = self.current_params
try:
if self.all_ifodata_same_rate_length:
wfs = self.waveform_generator.generate(**params)
else:
wfs = {}
for det in self.data:
wfs.update(self.waveform_generator[det].generate(**params))
except NoWaveformError:
return self._nowaveform_loglr()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_loglr()
else:
raise e

if self.all_ifodata_same_rate_length:
wfs = self.waveform_generator.generate(**params)
else:
wfs = {}
for det in self.data:
wfs.update(self.waveform_generator[det].generate(**params))
sh_total = hh_total = 0.
snr_estimate = {}
cplx_hpd = {}
Expand Down Expand Up @@ -436,7 +419,7 @@ def _extra_stats(self):
return ['loglr', 'maxl_polarization', 'maxl_loglr'] + \
['{}_optimal_snrsq'.format(det) for det in self._data]

def _nowaveform_loglr(self):
def _nowaveform_handler(self):
"""Convenience function to set loglr values if no waveform generated.
"""
setattr(self._current_stats, 'loglr', -numpy.inf)
Expand All @@ -447,6 +430,7 @@ def _nowaveform_loglr(self):
setattr(self._current_stats, '{}_optimal_snrsq'.format(det), 0.)
return -numpy.inf

@catch_waveform_error
def _loglr(self):
r"""Computes the log likelihood ratio,
Expand All @@ -464,20 +448,12 @@ def _loglr(self):
The value of the log likelihood ratio.
"""
params = self.current_params
try:
if self.all_ifodata_same_rate_length:
wfs = self.waveform_generator.generate(**params)
else:
wfs = {}
for det in self.data:
wfs.update(self.waveform_generator[det].generate(**params))
except NoWaveformError:
return self._nowaveform_loglr()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_loglr()
else:
raise e
if self.all_ifodata_same_rate_length:
wfs = self.waveform_generator.generate(**params)
else:
wfs = {}
for det in self.data:
wfs.update(self.waveform_generator[det].generate(**params))

lr = sh_total = hh_total = 0.
for det, (hp, hc) in wfs.items():
Expand Down Expand Up @@ -643,14 +619,15 @@ def _extra_stats(self):
"""
return ['maxl_polarization', 'maxl_phase', ]

def _nowaveform_loglr(self):
def _nowaveform_handler(self):
"""Convenience function to set loglr values if no waveform generated.
"""
# maxl phase doesn't exist, so set it to nan
setattr(self._current_stats, 'maxl_polarization', numpy.nan)
setattr(self._current_stats, 'maxl_phase', numpy.nan)
return -numpy.inf

@catch_waveform_error
def _loglr(self, return_unmarginalized=False):
r"""Computes the log likelihood ratio,
Expand All @@ -668,15 +645,7 @@ def _loglr(self, return_unmarginalized=False):
The value of the log likelihood ratio.
"""
params = self.current_params
try:
wfs = self.waveform_generator.generate(**params)
except NoWaveformError:
return self._nowaveform_loglr()
except FailedWaveformError as e:
if self.ignore_failed_waveforms:
return self._nowaveform_loglr()
else:
raise e
wfs = self.waveform_generator.generate(**params)

# ---------------------------------------------------------------------
# Some optimizations not yet taken:
Expand Down
Loading

0 comments on commit 242620a

Please sign in to comment.