From 279eedf00d12fe2e84d4488a05ef4d203ed10e60 Mon Sep 17 00:00:00 2001 From: Collin Capano Date: Wed, 18 Dec 2024 10:32:31 -0500 Subject: [PATCH] fix issues found by unit tests --- pycbc/inference/models/gaussian_noise.py | 69 +++++------ test/test_infmodel.py | 147 ++++++++++++----------- 2 files changed, 112 insertions(+), 104 deletions(-) diff --git a/pycbc/inference/models/gaussian_noise.py b/pycbc/inference/models/gaussian_noise.py index 760d99a30b0..1442d7ae7e9 100644 --- a/pycbc/inference/models/gaussian_noise.py +++ b/pycbc/inference/models/gaussian_noise.py @@ -38,6 +38,40 @@ fd_data_from_strain_dict, gate_overwhitened_data) +def catch_waveform_error(method): + """Decorator that will catch no waveform errors. + + This can be added to a method in an inference model. The decorator will + call the model's `_nowaveform_return` method if either of the following + happens when the wrapped method is executed: + + * A `NoWaveformError` is raised. + * A `RuntimeError` or `FailedWaveformError` is raised and the model has + an `ignore_failed_waveforms` attribute that is set to True. + + This requires the model to have a `_nowaveform_handler` method. + """ + # the functools.wroaps decorator preserves the original method's name + # and docstring + @wraps(method) + def method_wrapper(self, *args, **kwargs): + try: + retval = method(self, *args, **kwargs) + except NoWaveformError: + retval = self._nowaveform_handler() + except (RuntimeError, FailedWaveformError) as e: + try: + ignore_failed = self.ignore_failed_waveforms + except AttributeError: + ignore_failed = False + if ignore_failed: + retval = self._nowaveform_handler() + else: + raise e + return retval + return method_wrapper + + class BaseGaussianNoise(BaseDataModel, metaclass=ABCMeta): r"""Model for analyzing GW data with assuming a wide-sense stationary Gaussian noise model. @@ -493,7 +527,7 @@ def _nowaveform_handler(self): """ raise NotImplementedError( f"A waveform could not be generated, but this model does not know " - f"how to handle that. The parameters were: {self.current_params}".) + f"how to handle that. The parameters were: {self.current_params}.") @classmethod def from_config(cls, cp, data_section='data', data=None, psds=None, @@ -1232,36 +1266,3 @@ def create_waveform_generator( recalib=recalibration, gates=gates, **static_params) return waveform_generator - - -def catch_waveform_error(method): - """Decorator that will catch no waveform errors. - - This can be added to a method in an inference model. The decorator will - call the model's `_nowaveform_return` method if either of the following - happens when the wrapped method is executed: - - * A `NoWaveformError` is raised. - * A `RuntimeError` or `FailedWaveformError` is raised and the model has - an `ignore_failed_waveforms` attribute that is set to True. - - This requires the model to have a `_nowaveform_return` method. - """ - # the functools.wroaps decorator preserves the original method's name - # and docstring - @wraps(method) - def method_wrapper(self, *args, **kwargs): - try: - retval = method(self, *args, **kwargs) - except NoWaveformError: - retval = self._nowaveform_return() - except (RuntimeError, FailedWaveformError) as e: - try: - ignore_failed = self.ignore_failed_waveforms - except AttributeError: - ignore_failed = False - if ignore_failed: - retval = self._nowaveform_return() - raise e - return retval - return method_wrapper diff --git a/test/test_infmodel.py b/test/test_infmodel.py index b68ab96734b..3a779f6165c 100644 --- a/test/test_infmodel.py +++ b/test/test_infmodel.py @@ -178,84 +178,93 @@ def setUpClass(cls): cls.psds = {} cls.data = {} # static params for the test + tc = 1187008882.42840 + flow = 20 cls.static = { - 'f_final_func': 'SchwarzISCO', - 'approximant':"TaylorF2", + 'approximant':"IMRPhenomD", + 'mass1': 40., + 'mass2': 40., 'polarization': 0, 'ra': 3.44615914, 'dec': -0.40808407, - 'tc': 1187008882.42840, - 'distance': 42., + 'tc': tc, + 'distance': 100., 'inclination': 2.5 } - cls.variable = ['mass1', 'mass2', 'f_lower'] + cls.variable = ['spin1z', 'f_lower'] ifos = ['H1', 'L1', 'V1'] # generate the reference psd - flow = 20 - seglen = 256 + seglen = 8 delta_f = 1./seglen sample_rate = 4096 delta_t = 1./sample_rate flen = int(sample_rate * seglen / 2) + 1 psd = aLIGOZeroDetHighPower(flen, delta_f, flow) + # put non-zero values in the beginning and end of the psd + # so the gating models will work + psd[0:int(flow/delta_f+1)] = psd[int(flow/delta_f+1)] + psd[-2:] = psd[-2] seed = 1000 cls.flow = {'H1': flow, 'L1': flow, 'V1': flow} # generate the noise for ifo in ifos: - print("Generating {} noise".format(ifo)) - tsamples = int(seglen * sample_rate) ts = noise_from_psd(tsamples, delta_t, psd, seed=seed) ts._epoch = cls.static['tc'] - seglen/2 seed += 1027 - cls.data[ifo] = ts + cls.data[ifo] = ts.to_frequencyseries() cls.psds[ifo] = psd # setup priors - inclination_prior = SinAngle(inclination=None) - distance_prior = Uniform(distance=(10, 100)) - tc_prior = Uniform(tc=(m.time-0.1, m.time+0.1)) + spin_prior = Uniform(spin1z=(-1., 2.)) + flowbad = 4000. + flower_prior = Uniform(f_lower=(flow, flowbad+100.)) pol = UniformAngle(polarization=None) - cls.prior = JointDistribution(cls.variable, inclination_prior, - distance_prior) + cls.prior = JointDistribution(cls.variable, spin_prior, flower_prior) # set up for marginalized polarization tests cls.static2 = cls.static.copy() cls.static2.pop('polarization') cls.variable2 = cls.variable + ['polarization'] - cls.prior2 = JointDistribution(cls.variable2, inclination_prior, - distance_prior, pol) + cls.prior2 = JointDistribution(cls.variable2, spin_prior, flower_prior, + pol) # set up gated parameters staticgate = cls.static.copy() - staticgate['t_gate_start'] = static['tc'] - 1. - staticgate['t_gate_end'] = static['tc'] - self.staticgate = staticgate + staticgate['t_gate_start'] = tc - 0.05 + staticgate['t_gate_end'] = tc + cls.staticgate = staticgate # margpol staticgate2 = cls.static2.copy() - staticgate2['t_gate_start'] = static2['tc'] - 1. - staticgate2['t_gate_end'] = static2['tc'] - self.staticgate2 = staticgate2 + staticgate2['t_gate_start'] = tc - 0.05 + staticgate2['t_gate_end'] = tc + cls.staticgate2 = staticgate2 # the parameters to test: # these parameters should pass - cls.pass_params = {'mass1': 1.4, 'mass2': 1.4, 'f_lower': flow} + cls.pass_params = {'spin1z': 0., 'f_lower': flow} # these parameters should trigger a NoWaveformError - cls.nowf_params = {'mass1': 1.4, 'mass2': 1.4, 'f_lower': 4000.} + cls.nowf_params = {'spin1z': 0., 'f_lower': flowbad} # these parameters should cause a FailedWaveformError - cls.fail_params = {'mass1': -1.4, 'mass2': 1.4, 'f_lower': flow} - - def _run_tests(self, model): - # first check that the model works - model.update(**self.pass_params) - self.assertTrue(numpy.isfinite(model.loglr)) - # now check that a no waveform error is caught correctly - model.update(**self.nowf_params) - self.assertEqual(model.loglr == -numpy.inf) - # now check that a failed waveform is caught correctly - model.update(**self.fail_params) - self.assertEqual(model.loglr == -numpy.inf) - # now check that an error is raised if ignore_failed_waveforms is False - model.ignore_failed_waveforms = False - model.update(**self.fail_params) - model.assertRaises((FailedWaveformError, RuntimeError), model.loglr) + cls.fail_params = {'spin1z': 2., 'f_lower': flow} + + def _run_tests(self, model, check_pass=True, check_nowf=True, + check_failed=True, check_raises=True): + # check that the model works + if check_pass: + model.update(**self.pass_params) + self.assertTrue(numpy.isfinite(model.loglr)) + # check that a no waveform error is caught correctly + if check_nowf: + model.update(**self.nowf_params) + self.assertEqual(model.loglr, -numpy.inf) + # check that a failed waveform is caught correctly + if check_failed: + model.update(**self.fail_params) + self.assertEqual(model.loglr, -numpy.inf) + # check that an error is raised if ignore_failed_waveforms is False + if check_raises: + model.ignore_failed_waveforms = False + model.update(**self.fail_params) + with self.assertRaises((FailedWaveformError, RuntimeError)): + model.loglr def test_base_phase_marg(self): model = models.MarginalizedPhaseGaussianNoise( @@ -273,47 +282,43 @@ def test_relative_phase_marg(self): psds = self.psds, static_params = self.static, prior = self.prior, - fiducial_params = {'mass1':1.3756}, + fiducial_params = {}, + #fiducial_params = {'mass1': 40.}, epsilon = .1, - ) - self._run_tests(model) + ignore_failed_waveforms=True) + # relative model doesn't respect flower, so no point in testing nowf + self._run_tests(model, check_nowf=False) - def test_single_phase_marg(self): - model = models.SingleTemplate( - self.variable, copy.deepcopy(self.data), - low_frequency_cutoff=self.flow, - psds = self.psds, - static_params = self.static, - prior = self.prior, - ) - self._run_tests(model) - - def test_single_pol_phase_marg(self): - model = models.SingleTemplate( - self.variable2, copy.deepcopy(self.data), + def test_brute_pol_phase_marg(self): + # Uses the old polarization syntax untill we decide to remove it. + # Untill then, this also tests that that interface stays working. + model = models.BruteParallelGaussianMarginalize( + self.variable, data=copy.deepcopy(self.data), low_frequency_cutoff=self.flow, psds = self.psds, static_params = self.static2, - prior = self.prior2, - marginalize_vector_samples = 1000, - marginalize_vector_params = 'polarization', + prior = self.prior, + marginalize_phase=4, + cores=1, + base_model='marginalized_polarization', + ignore_failed_waveforms=True ) - self._run_tests(model) - - def test_brute_pol_phase_marg(self): - # Uses the old polarization syntax untill we decide to remove it. - # Untill then, this also tests that that interface stays working. + # we need to do the check raises test separately because the underlying + # base model's ignore_failed_waveforms needs to be set + self._run_tests(model, check_raises=False) model = models.BruteParallelGaussianMarginalize( self.variable, data=copy.deepcopy(self.data), low_frequency_cutoff=self.flow, psds = self.psds, static_params = self.static2, prior = self.prior, - marginalize_phase=400, + marginalize_phase=4, cores=1, base_model='marginalized_polarization', + ignore_failed_waveforms=False ) - self._run_tests(model) + self._run_tests(model, check_pass=False, check_nowf=False, + check_failed=False, check_raises=True) def test_gated_gaussian_noise(self): model = models.GatedGaussianNoise( @@ -321,16 +326,18 @@ def test_gated_gaussian_noise(self): low_frequency_cutoff=self.flow, psds=self.psds, static_params=self.staticgate, - prior=self.prior) + prior=self.prior, + ignore_failed_waveforms=True) self._run_tests(model) def test_gated_gaussian_margpol(self): - model = models.GatedGaussianNoise( + model = models.GatedGaussianMargPol( self.variable, data=copy.deepcopy(self.data), low_frequency_cutoff=self.flow, psds=self.psds, static_params=self.staticgate2, - prior=self.prior) + prior=self.prior, + ignore_failed_waveforms=True) self._run_tests(model)