From 1091d06741d296ae958a88697f90266987a75b24 Mon Sep 17 00:00:00 2001 From: acorreia61201 Date: Thu, 5 Oct 2023 18:06:14 +0000 Subject: [PATCH] Cleaning up before PR --- .../.ipynb_checkpoints/__init__-checkpoint.py | 3 - .../.ipynb_checkpoints/burn_in-checkpoint.py | 798 ----------- .../.ipynb_checkpoints/entropy-checkpoint.py | 242 ---- .../.ipynb_checkpoints/__init__-checkpoint.py | 347 ----- .../.ipynb_checkpoints/base-checkpoint.py | 886 ------------ .../base_data-checkpoint.py | 160 --- .../data_utils-checkpoint.py | 554 -------- .../gated_gaussian_noise-checkpoint.py | 1110 --------------- .../gated_gaussian_noise-mod-checkpoint.py | 911 ------------- .../gaussian_noise-checkpoint.py | 1202 ----------------- .../hierarchical-checkpoint.py | 566 -------- .../.ipynb_checkpoints/dynesty-checkpoint.py | 649 --------- .../scatter_histograms-checkpoint.py | 867 ------------ .../.ipynb_checkpoints/gate-checkpoint.py | 183 --- 14 files changed, 8478 deletions(-) delete mode 100644 pycbc/inference/.ipynb_checkpoints/__init__-checkpoint.py delete mode 100644 pycbc/inference/.ipynb_checkpoints/burn_in-checkpoint.py delete mode 100644 pycbc/inference/.ipynb_checkpoints/entropy-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/__init__-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/base-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/base_data-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/data_utils-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-mod-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/gaussian_noise-checkpoint.py delete mode 100644 pycbc/inference/models/.ipynb_checkpoints/hierarchical-checkpoint.py delete mode 100644 pycbc/inference/sampler/.ipynb_checkpoints/dynesty-checkpoint.py delete mode 100644 pycbc/results/.ipynb_checkpoints/scatter_histograms-checkpoint.py delete mode 100644 pycbc/strain/.ipynb_checkpoints/gate-checkpoint.py diff --git a/pycbc/inference/.ipynb_checkpoints/__init__-checkpoint.py b/pycbc/inference/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index 55c3d58ec84..00000000000 --- a/pycbc/inference/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,3 +0,0 @@ -# pylint: disable=unused-import -from . import (models, sampler, io) -from . import (burn_in, entropy, gelman_rubin, geweke, option_utils) diff --git a/pycbc/inference/.ipynb_checkpoints/burn_in-checkpoint.py b/pycbc/inference/.ipynb_checkpoints/burn_in-checkpoint.py deleted file mode 100644 index b5cd2970a06..00000000000 --- a/pycbc/inference/.ipynb_checkpoints/burn_in-checkpoint.py +++ /dev/null @@ -1,798 +0,0 @@ -# Copyright (C) 2017 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. 
-# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# -""" -This modules provides classes and functions for determining when Markov Chains -have burned in. -""" - - -import logging -from abc import ABCMeta, abstractmethod -import numpy -from scipy.stats import ks_2samp - -from pycbc.io.record import get_vars_from_arg - -# The value to use for a burn-in iteration if a chain is not burned in -NOT_BURNED_IN_ITER = -1 - - -# -# ============================================================================= -# -# Convenience functions -# -# ============================================================================= -# - - -def ks_test(samples1, samples2, threshold=0.9): - """Applies a KS test to determine if two sets of samples are the same. - - The ks test is applied parameter-by-parameter. If the two-tailed p-value - returned by the test is greater than ``threshold``, the samples are - considered to be the same. - - Parameters - ---------- - samples1 : dict - Dictionary of mapping parameters to the first set of samples. - samples2 : dict - Dictionary of mapping parameters to the second set of samples. - threshold : float - The thershold to use for the p-value. Default is 0.9. - - Returns - ------- - dict : - Dictionary mapping parameter names to booleans indicating whether the - given parameter passes the KS test. - """ - is_the_same = {} - assert set(samples1.keys()) == set(samples2.keys()), ( - "samples1 and 2 must have the same parameters") - # iterate over the parameters - for param in samples1: - s1 = samples1[param] - s2 = samples2[param] - _, p_value = ks_2samp(s1, s2) - is_the_same[param] = p_value > threshold - return is_the_same - - -def max_posterior(lnps_per_walker, dim): - """Burn in based on samples being within dim/2 of maximum posterior. - - Parameters - ---------- - lnps_per_walker : 2D array - Array of values that are proportional to the log posterior values. Must - have shape ``nwalkers x niterations``. - dim : int - The dimension of the parameter space. - - Returns - ------- - burn_in_idx : array of int - The burn in indices of each walker. If a walker is not burned in, its - index will be be equal to the length of the chain. - is_burned_in : array of bool - Whether or not a walker is burned in. - """ - if len(lnps_per_walker.shape) != 2: - raise ValueError("lnps_per_walker must have shape " - "nwalkers x niterations") - # find the value to compare against - max_p = lnps_per_walker.max() - criteria = max_p - dim/2. - nwalkers, _ = lnps_per_walker.shape - burn_in_idx = numpy.empty(nwalkers, dtype=int) - is_burned_in = numpy.empty(nwalkers, dtype=bool) - # find the first iteration in each chain where the logpost has exceeded - # max_p - dim/2 - for ii in range(nwalkers): - chain = lnps_per_walker[ii, :] - passedidx = numpy.where(chain >= criteria)[0] - is_burned_in[ii] = passedidx.size > 0 - if is_burned_in[ii]: - burn_in_idx[ii] = passedidx[0] - else: - burn_in_idx[ii] = NOT_BURNED_IN_ITER - return burn_in_idx, is_burned_in - - -def posterior_step(logposts, dim): - """Finds the last time a chain made a jump > dim/2. 
- - Parameters - ---------- - logposts : array - 1D array of values that are proportional to the log posterior values. - dim : int - The dimension of the parameter space. - - Returns - ------- - int - The index of the last time the logpost made a jump > dim/2. If that - never happened, returns 0. - """ - if logposts.ndim > 1: - raise ValueError("logposts must be a 1D array") - criteria = dim/2. - dp = numpy.diff(logposts) - indices = numpy.where(dp >= criteria)[0] - if indices.size > 0: - idx = indices[-1] + 1 - else: - idx = 0 - return idx - - -def nacl(nsamples, acls, nacls=5): - """Burn in based on ACL. - - This applies the following test to determine burn in: - - 1. The first half of the chain is ignored. - - 2. An ACL is calculated from the second half. - - 3. If ``nacls`` times the ACL is < the length of the chain / 2, - the chain is considered to be burned in at the half-way point. - - Parameters - ---------- - nsamples : int - The number of samples of in the chain(s). - acls : dict - Dictionary of parameter -> ACL(s). The ACLs for each parameter may - be an integer or an array of integers (for multiple chains). - nacls : int, optional - The number of ACLs the chain(s) must have gone past the halfway point - in order to be considered burned in. Default is 5. - - Returns - ------- - dict - Dictionary of parameter -> boolean(s) indicating if the chain(s) pass - the test. If an array of values was provided for the acls, the values - will be arrays of booleans. - """ - kstart = int(nsamples / 2.) - return {param: (nacls * acl) < kstart for (param, acl) in acls.items()} - - -def evaluate_tests(burn_in_test, test_is_burned_in, test_burn_in_iter): - """Evaluates burn in data from multiple tests. - - The iteration to use for burn-in depends on the logic in the burn-in - test string. For example, if the test was 'max_posterior | nacl' and - max_posterior burned-in at iteration 5000 while nacl burned in at - iteration 6000, we'd want to use 5000 as the burn-in iteration. - However, if the test was 'max_posterior & nacl', we'd want to use - 6000 as the burn-in iteration. This function handles all cases by - doing the following: first, take the collection of burn in iterations - from all the burn in tests that were applied. Next, cycle over the - iterations in increasing order, checking which tests have burned in - by that point. Then evaluate the burn-in string at that point to see - if it passes, and if so, what the iteration is. The first point that - the test passes is used as the burn-in iteration. - - Parameters - ---------- - burn_in_test : str - The test to apply; e.g., ``'max_posterior & nacl'``. - test_is_burned_in : dict - Dictionary of test name -> boolean indicating whether a specific burn - in test has passed. - test_burn_in_iter : dict - Dictionary of test name -> int indicating when a specific test burned - in. - - Returns - ------- - is_burned_in : bool - Whether or not the data passes all burn in tests. - burn_in_iteration : - The iteration at which all the tests pass. If the tests did not all - pass (``is_burned_in`` is false), then returns - :py:data:`NOT_BURNED_IN_ITER`. 
- """ - burn_in_iters = numpy.unique(list(test_burn_in_iter.values())) - burn_in_iters.sort() - for ii in burn_in_iters: - test_results = {t: (test_is_burned_in[t] & - 0 <= test_burn_in_iter[t] <= ii) - for t in test_is_burned_in} - is_burned_in = eval(burn_in_test, {"__builtins__": None}, - test_results) - if is_burned_in: - break - if not is_burned_in: - ii = NOT_BURNED_IN_ITER - return is_burned_in, ii - - -# -# ============================================================================= -# -# Burn in classes -# -# ============================================================================= -# - - -class BaseBurnInTests(metaclass=ABCMeta): - """Base class for burn in tests.""" - - available_tests = ('halfchain', 'min_iterations', 'max_posterior', - 'posterior_step', 'nacl', - ) - - # pylint: disable=unnecessary-pass - - def __init__(self, sampler, burn_in_test, **kwargs): - self.sampler = sampler - # determine the burn-in tests that are going to be done - self.do_tests = get_vars_from_arg(burn_in_test) - self.burn_in_test = burn_in_test - self.is_burned_in = False - self.burn_in_iteration = NOT_BURNED_IN_ITER - self.test_is_burned_in = {} # burn in status per test - self.test_burn_in_iteration = {} # burn in iter per test - self.test_aux_info = {} # any additional information the test stores - # Arguments specific to each test... - # for nacl: - self._nacls = int(kwargs.pop('nacls', 5)) - # for max_posterior and posterior_step - self._ndim = int(kwargs.pop('ndim', len(sampler.variable_params))) - # for min iterations - self._min_iterations = int(kwargs.pop('min_iterations', 0)) - - @abstractmethod - def burn_in_index(self, filename): - """The burn in index (retrieved from the iteration). - - This is an abstract method because how this is evaluated depends on - if this is an ensemble MCMC or not. - """ - pass - - def _getniters(self, filename): - """Convenience function to get the number of iterations in the file. - - If `niterations` hasn't been written to the file yet, just returns 0. - """ - with self.sampler.io(filename, 'r') as fp: - try: - niters = fp.niterations - except KeyError: - niters = 0 - return niters - - def _getnsamples(self, filename): - """Convenience function to get the number of samples saved in the file. - - If no samples have been written to the file yet, just returns 0. - """ - with self.sampler.io(filename, 'r') as fp: - try: - group = fp[fp.samples_group] - # we'll just use the first parameter - params = list(group.keys()) - nsamples = group[params[0]].shape[-1] - except (KeyError, IndexError): - nsamples = 0 - return nsamples - - def _index2iter(self, filename, index): - """Converts the index in some samples at which burn in occurs to the - iteration of the sampler that corresponds to. - """ - with self.sampler.io(filename, 'r') as fp: - thin_interval = fp.thinned_by - return index * thin_interval - - def _iter2index(self, filename, iteration): - """Converts an iteration to the index it corresponds to. - """ - with self.sampler.io(filename, 'r') as fp: - thin_interval = fp.thinned_by - return iteration // thin_interval - - def _getlogposts(self, filename): - """Convenience function for retrieving log posteriors. - - Parameters - ---------- - filename : str - The file to read. - - Returns - ------- - array - The log posterior values. They are not flattened, so have dimension - nwalkers x niterations. 
- """ - with self.sampler.io(filename, 'r') as fp: - samples = fp.read_raw_samples( - ['loglikelihood', 'logprior'], thin_start=0, thin_interval=1, - flatten=False) - logposts = samples['loglikelihood'] + samples['logprior'] - return logposts - - def _getacls(self, filename, start_index): - """Convenience function for calculating acls for the given filename. - """ - return self.sampler.compute_acl(filename, start_index=start_index) - - def _getaux(self, test): - """Convenience function for getting auxilary information. - - Parameters - ---------- - test : str - The name of the test to retrieve auxilary information about. - - Returns - ------- - dict - The ``test_aux_info[test]`` dictionary. If a dictionary does - not exist yet for the given test, an empty dictionary will be - created and saved to ``test_aux_info[test]``. - """ - try: - aux = self.test_aux_info[test] - except KeyError: - aux = self.test_aux_info[test] = {} - return aux - - def halfchain(self, filename): - """Just uses half the chain as the burn-in iteration. - """ - niters = self._getniters(filename) - # this test cannot determine when something will burn in - # only when it was not burned in in the past - self.test_is_burned_in['halfchain'] = True - self.test_burn_in_iteration['halfchain'] = niters//2 - - def min_iterations(self, filename): - """Just checks that the sampler has been run for the minimum number - of iterations. - """ - niters = self._getniters(filename) - is_burned_in = self._min_iterations < niters - if is_burned_in: - burn_in_iter = self._min_iterations - else: - burn_in_iter = NOT_BURNED_IN_ITER - self.test_is_burned_in['min_iterations'] = is_burned_in - self.test_burn_in_iteration['min_iterations'] = burn_in_iter - - @abstractmethod - def max_posterior(self, filename): - """Carries out the max posterior test and stores the results.""" - pass - - @abstractmethod - def posterior_step(self, filename): - """Carries out the posterior step test and stores the results.""" - pass - - @abstractmethod - def nacl(self, filename): - """Carries out the nacl test and stores the results.""" - pass - - @abstractmethod - def evaluate(self, filename): - """Performs all tests and evaluates the results to determine if and - when all tests pass. - """ - pass - - def write(self, fp, path=None): - """Writes burn-in info to an open HDF file. - - Parameters - ---------- - fp : pycbc.inference.io.base.BaseInferenceFile - Open HDF file to write the data to. The HDF file should be an - instance of a pycbc BaseInferenceFile. - path : str, optional - Path in the HDF file to write the data to. Default is (None) is - to write to the path given by the file's ``sampler_group`` - attribute. 
- """ - if path is None: - path = fp.sampler_group - fp.write_data('burn_in_test', self.burn_in_test, path) - fp.write_data('is_burned_in', self.is_burned_in, path) - fp.write_data('burn_in_iteration', self.burn_in_iteration, path) - testgroup = 'burn_in_tests' - # write individual test data - for tst in self.do_tests: - subpath = '/'.join([path, testgroup, tst]) - fp.write_data('is_burned_in', self.test_is_burned_in[tst], subpath) - fp.write_data('burn_in_iteration', - self.test_burn_in_iteration[tst], - subpath) - # write auxiliary info - if tst in self.test_aux_info: - for name, data in self.test_aux_info[tst].items(): - fp.write_data(name, data, subpath) - - @staticmethod - def _extra_tests_from_config(cp, section, tag): - """For loading class-specific tests.""" - # pylint: disable=unused-argument - return {} - - @classmethod - def from_config(cls, cp, sampler): - """Loads burn in from section [sampler-burn_in].""" - section = 'sampler' - tag = 'burn_in' - burn_in_test = cp.get_opt_tag(section, 'burn-in-test', tag) - kwargs = {} - if cp.has_option_tag(section, 'nacl', tag): - kwargs['nacl'] = int(cp.get_opt_tag(section, 'nacl', tag)) - if cp.has_option_tag(section, 'ndim', tag): - kwargs['ndim'] = int( - cp.get_opt_tag(section, 'ndim', tag)) - if cp.has_option_tag(section, 'min-iterations', tag): - kwargs['min_iterations'] = int( - cp.get_opt_tag(section, 'min-iterations', tag)) - # load any class specific tests - kwargs.update(cls._extra_tests_from_config(cp, section, tag)) - return cls(sampler, burn_in_test, **kwargs) - - -class MCMCBurnInTests(BaseBurnInTests): - """Burn-in tests for collections of independent MCMC chains. - - This differs from EnsembleMCMCBurnInTests in that chains are treated as - being independent of each other. The ``is_burned_in`` attribute will be - True if `any` chain passes the burn in tests (whereas in MCMCBurnInTests, - all chains must pass the burn in tests). In other words, independent - samples can be collected even if all of the chains are not burned in. 
- """ - def __init__(self, sampler, burn_in_test, **kwargs): - super(MCMCBurnInTests, self).__init__(sampler, burn_in_test, **kwargs) - try: - nchains = sampler.nchains - except AttributeError: - nchains = sampler.nwalkers - self.nchains = nchains - self.is_burned_in = numpy.zeros(self.nchains, dtype=bool) - self.burn_in_iteration = numpy.repeat(NOT_BURNED_IN_ITER, self.nchains) - - def burn_in_index(self, filename): - """The burn in index (retrieved from the iteration).""" - burn_in_index = self._iter2index(filename, self.burn_in_iteration) - # don't set if it isn't burned in - burn_in_index[~self.is_burned_in] = NOT_BURNED_IN_ITER - return burn_in_index - - def max_posterior(self, filename): - """Applies max posterior test.""" - logposts = self._getlogposts(filename) - burn_in_idx, is_burned_in = max_posterior(logposts, self._ndim) - # convert index to iterations - burn_in_iter = self._index2iter(filename, burn_in_idx) - burn_in_iter[~is_burned_in] = NOT_BURNED_IN_ITER - # save - test = 'max_posterior' - self.test_is_burned_in[test] = is_burned_in - self.test_burn_in_iteration[test] = burn_in_iter - - def posterior_step(self, filename): - """Applies the posterior-step test.""" - logposts = self._getlogposts(filename) - burn_in_idx = numpy.array([posterior_step(logps, self._ndim) - for logps in logposts]) - # this test cannot determine when something will burn in - # only when it was not burned in in the past - test = 'posterior_step' - if test not in self.test_is_burned_in: - self.test_is_burned_in[test] = numpy.ones(self.nchains, dtype=bool) - # convert index to iterations - self.test_burn_in_iteration[test] = self._index2iter(filename, - burn_in_idx) - - def nacl(self, filename): - """Applies the :py:func:`nacl` test.""" - nsamples = self._getnsamples(filename) - acls = self._getacls(filename, start_index=nsamples//2) - is_burned_in = nacl(nsamples, acls, self._nacls) - # stack the burn in results into an nparams x nchains array - burn_in_per_chain = numpy.stack(list(is_burned_in.values())).all( - axis=0) - # store - test = 'nacl' - self.test_is_burned_in[test] = burn_in_per_chain - try: - burn_in_iter = self.test_burn_in_iteration[test] - except KeyError: - # hasn't been stored yet - burn_in_iter = numpy.repeat(NOT_BURNED_IN_ITER, self.nchains) - self.test_burn_in_iteration[test] = burn_in_iter - burn_in_iter[burn_in_per_chain] = self._index2iter(filename, - nsamples//2) - # add the status for each parameter as additional information - self.test_aux_info[test] = is_burned_in - - def evaluate(self, filename): - """Runs all of the burn-in tests.""" - # evaluate all the tests - for tst in self.do_tests: - logging.info("Evaluating %s burn-in test", tst) - getattr(self, tst)(filename) - # evaluate each chain at a time - for ci in range(self.nchains): - # some tests (like halfchain) just store a single bool for all - # chains - tibi = {t: r[ci] if isinstance(r, numpy.ndarray) else r - for t, r in self.test_is_burned_in.items()} - tbi = {t: r[ci] if isinstance(r, numpy.ndarray) else r - for t, r in self.test_burn_in_iteration.items()} - is_burned_in, burn_in_iter = evaluate_tests(self.burn_in_test, - tibi, tbi) - self.is_burned_in[ci] = is_burned_in - self.burn_in_iteration[ci] = burn_in_iter - logging.info("Number of chains burned in: %i of %i", - self.is_burned_in.sum(), self.nchains) - - def write(self, fp, path=None): - """Writes burn-in info to an open HDF file. - - Parameters - ---------- - fp : pycbc.inference.io.base.BaseInferenceFile - Open HDF file to write the data to. 
The HDF file should be an - instance of a pycbc BaseInferenceFile. - path : str, optional - Path in the HDF file to write the data to. Default is (None) is - to write to the path given by the file's ``sampler_group`` - attribute. - """ - if path is None: - path = fp.sampler_group - super(MCMCBurnInTests, self).write(fp, path) - # add number of chains burned in as additional metadata - fp.write_data('nchains_burned_in', self.is_burned_in.sum(), path) - - -class MultiTemperedMCMCBurnInTests(MCMCBurnInTests): - """Adds support for multiple temperatures to - :py:class:`MCMCBurnInTests`. - """ - - def _getacls(self, filename, start_index): - """Convenience function for calculating acls for the given filename. - - This function is used by the ``n_acl`` burn-in test. That function - expects the returned ``acls`` dict to just report a single ACL for - each parameter. Since multi-tempered samplers return an array of ACLs - for each parameter instead, this takes the max over the array before - returning. - - Since we calculate the acls, this will also store it to the sampler. - - Parameters - ---------- - filename : str - Name of the file to retrieve samples from. - start_index : int - Index to start calculating ACLs. - - Returns - ------- - dict : - Dictionary of parameter names -> array giving ACL for each chain. - """ - acls = super(MultiTemperedMCMCBurnInTests, self)._getacls( - filename, start_index) - # acls will have shape ntemps x nchains, flatten to nchains - return {param: vals.max(axis=0) for (param, vals) in acls.items()} - - def _getlogposts(self, filename): - """Convenience function for retrieving log posteriors. - - This just gets the coldest temperature chain, and returns arrays with - shape nwalkers x niterations, so the parent class can run the same - ``posterior_step`` function. 
- """ - return _multitemper_getlogposts(self.sampler, filename) - - -class EnsembleMCMCBurnInTests(BaseBurnInTests): - """Provides methods for estimating burn-in of an ensemble MCMC.""" - - available_tests = ('halfchain', 'min_iterations', 'max_posterior', - 'posterior_step', 'nacl', 'ks_test', - ) - - def __init__(self, sampler, burn_in_test, **kwargs): - super(EnsembleMCMCBurnInTests, self).__init__( - sampler, burn_in_test, **kwargs) - # for kstest - self._ksthreshold = float(kwargs.pop('ks_threshold', 0.9)) - - def burn_in_index(self, filename): - """The burn in index (retrieved from the iteration).""" - if self.is_burned_in: - index = self._iter2index(filename, self.burn_in_iteration) - else: - index = NOT_BURNED_IN_ITER - return index - - def max_posterior(self, filename): - """Applies max posterior test to self.""" - logposts = self._getlogposts(filename) - burn_in_idx, is_burned_in = max_posterior(logposts, self._ndim) - all_burned_in = is_burned_in.all() - if all_burned_in: - burn_in_iter = self._index2iter(filename, burn_in_idx.max()) - else: - burn_in_iter = NOT_BURNED_IN_ITER - # store - test = 'max_posterior' - self.test_is_burned_in[test] = all_burned_in - self.test_burn_in_iteration[test] = burn_in_iter - aux = self._getaux(test) - # additional info - aux['iteration_per_walker'] = self._index2iter(filename, burn_in_idx) - aux['status_per_walker'] = is_burned_in - - def posterior_step(self, filename): - """Applies the posterior-step test.""" - logposts = self._getlogposts(filename) - burn_in_idx = numpy.array([posterior_step(logps, self._ndim) - for logps in logposts]) - burn_in_iters = self._index2iter(filename, burn_in_idx) - # this test cannot determine when something will burn in - # only when it was not burned in in the past - test = 'posterior_step' - self.test_is_burned_in[test] = True - self.test_burn_in_iteration[test] = burn_in_iters.max() - # store the iteration per walker as additional info - aux = self._getaux(test) - aux['iteration_per_walker'] = burn_in_iters - - def nacl(self, filename): - """Applies the :py:func:`nacl` test.""" - nsamples = self._getnsamples(filename) - acls = self._getacls(filename, start_index=nsamples//2) - is_burned_in = nacl(nsamples, acls, self._nacls) - all_burned_in = all(is_burned_in.values()) - if all_burned_in: - burn_in_iter = self._index2iter(filename, nsamples//2) - else: - burn_in_iter = NOT_BURNED_IN_ITER - # store - test = 'nacl' - self.test_is_burned_in[test] = all_burned_in - self.test_burn_in_iteration[test] = burn_in_iter - # store the status per parameter as additional info - aux = self._getaux(test) - aux['status_per_parameter'] = is_burned_in - - def ks_test(self, filename): - """Applies ks burn-in test.""" - nsamples = self._getnsamples(filename) - with self.sampler.io(filename, 'r') as fp: - # get the samples from the mid point - samples1 = fp.read_raw_samples( - ['loglikelihood', 'logprior'], iteration=int(nsamples/2.)) - # get the last samples - samples2 = fp.read_raw_samples( - ['loglikelihood', 'logprior'], iteration=-1) - # do the test - # is_the_same is a dictionary of params --> bool indicating whether or - # not the 1D marginal is the same at the half way point - is_the_same = ks_test(samples1, samples2, threshold=self._ksthreshold) - is_burned_in = all(is_the_same.values()) - if is_burned_in: - burn_in_iter = self._index2iter(filename, int(nsamples//2)) - else: - burn_in_iter = NOT_BURNED_IN_ITER - # store - test = 'ks_test' - self.test_is_burned_in[test] = is_burned_in - self.test_burn_in_iteration[test] = 
burn_in_iter - # store the test per parameter as additional info - aux = self._getaux(test) - aux['status_per_parameter'] = is_the_same - - def evaluate(self, filename): - """Runs all of the burn-in tests.""" - # evaluate all the tests - for tst in self.do_tests: - logging.info("Evaluating %s burn-in test", tst) - getattr(self, tst)(filename) - is_burned_in, burn_in_iter = evaluate_tests( - self.burn_in_test, self.test_is_burned_in, - self.test_burn_in_iteration) - self.is_burned_in = is_burned_in - self.burn_in_iteration = burn_in_iter - logging.info("Is burned in: %r", self.is_burned_in) - if self.is_burned_in: - logging.info("Burn-in iteration: %i", - int(self.burn_in_iteration)) - - @staticmethod - def _extra_tests_from_config(cp, section, tag): - """Loads the ks test settings from the config file.""" - kwargs = {} - if cp.has_option_tag(section, 'ks-threshold', tag): - kwargs['ks_threshold'] = float( - cp.get_opt_tag(section, 'ks-threshold', tag)) - return kwargs - - -class EnsembleMultiTemperedMCMCBurnInTests(EnsembleMCMCBurnInTests): - """Adds support for multiple temperatures to - :py:class:`EnsembleMCMCBurnInTests`. - """ - - def _getacls(self, filename, start_index): - """Convenience function for calculating acls for the given filename. - - This function is used by the ``n_acl`` burn-in test. That function - expects the returned ``acls`` dict to just report a single ACL for - each parameter. Since multi-tempered samplers return an array of ACLs - for each parameter instead, this takes the max over the array before - returning. - - Since we calculate the acls, this will also store it to the sampler. - """ - acls = super(EnsembleMultiTemperedMCMCBurnInTests, self)._getacls( - filename, start_index) - # return the max for each parameter - return {param: vals.max() for (param, vals) in acls.items()} - - def _getlogposts(self, filename): - """Convenience function for retrieving log posteriors. - - This just gets the coldest temperature chain, and returns arrays with - shape nwalkers x niterations, so the parent class can run the same - ``posterior_step`` function. - """ - return _multitemper_getlogposts(self.sampler, filename) - - -def _multitemper_getlogposts(sampler, filename): - """Retrieve log posteriors for multi tempered samplers.""" - with sampler.io(filename, 'r') as fp: - samples = fp.read_raw_samples( - ['loglikelihood', 'logprior'], thin_start=0, thin_interval=1, - temps=0, flatten=False) - # reshape to drop the first dimension - for (stat, arr) in samples.items(): - _, nwalkers, niterations = arr.shape - samples[stat] = arr.reshape((nwalkers, niterations)) - logposts = samples['loglikelihood'] + samples['logprior'] - return logposts diff --git a/pycbc/inference/.ipynb_checkpoints/entropy-checkpoint.py b/pycbc/inference/.ipynb_checkpoints/entropy-checkpoint.py deleted file mode 100644 index 9f71183fd55..00000000000 --- a/pycbc/inference/.ipynb_checkpoints/entropy-checkpoint.py +++ /dev/null @@ -1,242 +0,0 @@ -""" The module contains functions for calculating the -Kullback-Leibler divergence. -""" - -import numpy -from scipy import stats - - -def check_hist_params(samples, hist_min, hist_max, hist_bins): - """ Checks that the bound values given for the histogram are consistent, - returning the range if they are or raising an error if they are not. - Also checks that if hist_bins is a str, it corresponds to a method - available in numpy.histogram - - Parameters - ---------- - samples : numpy.array - Set of samples to get the min/max if only one of the bounds is given. 
- hist_min : numpy.float64 - Minimum value for the histogram. - hist_max : numpy.float64 - Maximum value for the histogram. - hist_bins: int or str - If int, number of equal-width bins to use in numpy.histogram. If str, - it should be one of the methods to calculate the optimal bin width - available in numpy.histogram: ['auto', 'fd', 'doane', 'scott', 'stone', - 'rice', 'sturges', 'sqrt']. Default is 'fd' (Freedman Diaconis - Estimator). This option will be ignored if `kde=True`. - - Returns - ------- - hist_range : tuple or None - The bounds (hist_min, hist_max) or None. - hist_bins : int or str - Number of bins or method for optimal width bin calculation. - """ - - hist_methods = ['auto', 'fd', 'doane', 'scott', 'stone', 'rice', - 'sturges', 'sqrt'] - if not hist_bins: - hist_bins = 'fd' - elif isinstance(hist_bins, str) and hist_bins not in hist_methods: - raise ValueError('Method for calculating bins width must be one of' - ' {}'.format(hist_methods)) - - # No bounds given, return None - if not hist_min and not hist_max: - return None, hist_bins - - # One of the bounds is missing - if hist_min and not hist_max: - hist_max = samples.max() - elif hist_max and not hist_min: - hist_min = samples.min() - # Both bounds given - elif hist_min and hist_max and hist_min >= hist_max: - raise ValueError('hist_min must be lower than hist_max.') - - hist_range = (hist_min, hist_max) - - return hist_range, hist_bins - - -def compute_pdf(samples, method, bins, hist_min, hist_max): - """ Computes the probability density function for a set of samples. - - Parameters - ---------- - samples : numpy.array - Set of samples to calculate the pdf. - method : str - Method to calculate the pdf. Options are 'kde' for the Kernel Density - Estimator, and 'hist' to use numpy.histogram - bins : str or int, optional - This option will be ignored if method is `kde`. - If int, number of equal-width bins to use when calculating probability - density function from a set of samples of the distribution. If str, it - should be one of the methods to calculate the optimal bin width - available in numpy.histogram: ['auto', 'fd', 'doane', 'scott', 'stone', - 'rice', 'sturges', 'sqrt']. Default is 'fd' (Freedman Diaconis - Estimator). - hist_min : numpy.float64, optional - Minimum of the distributions' values to use. This will be ignored if - `kde=True`. - hist_max : numpy.float64, optional - Maximum of the distributions' values to use. This will be ignored if - `kde=True`. - - Returns - ------- - pdf : numpy.array - Discrete probability distribution calculated from samples. - """ - - if method == 'kde': - samples_kde = stats.gaussian_kde(samples) - npts = 10000 if len(samples) <= 10000 else len(samples) - draw = samples_kde.resample(npts) - pdf = samples_kde.evaluate(draw) - elif method == 'hist': - hist_range, hist_bins = check_hist_params(samples, hist_min, - hist_max, bins) - pdf, _ = numpy.histogram(samples, bins=hist_bins, - range=hist_range, density=True) - else: - raise ValueError('Method not recognized.') - - return pdf - - -def entropy(pdf1, base=numpy.e): - """ Computes the information entropy for a single parameter - from one probability density function. - - Parameters - ---------- - pdf1 : numpy.array - Probability density function. - base : {numpy.e, numpy.float64}, optional - The logarithmic base to use (choose base 2 for information measured - in bits, default is nats). - - Returns - ------- - numpy.float64 - The information entropy value. 
- """ - - return stats.entropy(pdf1, base=base) - - -def kl(samples1, samples2, pdf1=False, pdf2=False, kde=False, - bins=None, hist_min=None, hist_max=None, base=numpy.e): - """ Computes the Kullback-Leibler divergence for a single parameter - from two distributions. - - Parameters - ---------- - samples1 : numpy.array - Samples or probability density function (for the latter must also set - `pdf1=True`). - samples2 : numpy.array - Samples or probability density function (for the latter must also set - `pdf2=True`). - pdf1 : bool - Set to `True` if `samples1` is a probability density funtion already. - pdf2 : bool - Set to `True` if `samples2` is a probability density funtion already. - kde : bool - Set to `True` if at least one of `pdf1` or `pdf2` is `False` to - estimate the probability density function using kernel density - estimation (KDE). - bins : int or str, optional - If int, number of equal-width bins to use when calculating probability - density function from a set of samples of the distribution. If str, it - should be one of the methods to calculate the optimal bin width - available in numpy.histogram: ['auto', 'fd', 'doane', 'scott', 'stone', - 'rice', 'sturges', 'sqrt']. Default is 'fd' (Freedman Diaconis - Estimator). This option will be ignored if `kde=True`. - hist_min : numpy.float64 - Minimum of the distributions' values to use. This will be ignored if - `kde=True`. - hist_max : numpy.float64 - Maximum of the distributions' values to use. This will be ignored if - `kde=True`. - base : numpy.float64 - The logarithmic base to use (choose base 2 for information measured - in bits, default is nats). - - Returns - ------- - numpy.float64 - The Kullback-Leibler divergence value. - """ - if pdf1 and pdf2 and kde: - raise ValueError('KDE can only be used when at least one of pdf1 or ' - 'pdf2 is False.') - - sample_groups = {'P': (samples1, pdf1), 'Q': (samples2, pdf2)} - pdfs = {} - for n in sample_groups: - samples, pdf = sample_groups[n] - if pdf: - pdfs[n] = samples - else: - method = 'kde' if kde else 'hist' - pdfs[n] = compute_pdf(samples, method, bins, hist_min, hist_max) - - return stats.entropy(pdfs['P'], qk=pdfs['Q'], base=base) - - -def js(samples1, samples2, kde=False, bins=None, hist_min=None, hist_max=None, - base=numpy.e): - """ Computes the Jensen-Shannon divergence for a single parameter - from two distributions. - - Parameters - ---------- - samples1 : numpy.array - Samples. - samples2 : numpy.array - Samples. - kde : bool - Set to `True` to estimate the probability density function using - kernel density estimation (KDE). - bins : int or str, optional - If int, number of equal-width bins to use when calculating probability - density function from a set of samples of the distribution. If str, it - should be one of the methods to calculate the optimal bin width - available in numpy.histogram: ['auto', 'fd', 'doane', 'scott', 'stone', - 'rice', 'sturges', 'sqrt']. Default is 'fd' (Freedman Diaconis - Estimator). This option will be ignored if `kde=True`. - hist_min : numpy.float64 - Minimum of the distributions' values to use. This will be ignored if - `kde=True`. - hist_max : numpy.float64 - Maximum of the distributions' values to use. This will be ignored if - `kde=True`. - base : numpy.float64 - The logarithmic base to use (choose base 2 for information measured - in bits, default is nats). - - Returns - ------- - numpy.float64 - The Jensen-Shannon divergence value. 
- """ - - sample_groups = {'P': samples1, 'Q': samples2} - pdfs = {} - for n in sample_groups: - samples = sample_groups[n] - method = 'kde' if kde else 'hist' - pdfs[n] = compute_pdf(samples, method, bins, hist_min, hist_max) - - pdfs['M'] = (1./2) * (pdfs['P'] + pdfs['Q']) - - js_div = 0 - for pdf in (pdfs['P'], pdfs['Q']): - js_div += (1./2) * kl(pdf, pdfs['M'], pdf1=True, pdf2=True, base=base) - - return js_div diff --git a/pycbc/inference/models/.ipynb_checkpoints/__init__-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/__init__-checkpoint.py deleted file mode 100644 index ed69894124e..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/__init__-checkpoint.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (C) 2018 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -""" -This package provides classes and functions for evaluating Bayesian statistics -assuming various noise models. -""" - - -import logging -from pkg_resources import iter_entry_points as _iter_entry_points -from .base import BaseModel -from .base_data import BaseDataModel -from .analytic import (TestEggbox, TestNormal, TestRosenbrock, TestVolcano, - TestPrior, TestPosterior) -from .gaussian_noise import GaussianNoise -from .marginalized_gaussian_noise import MarginalizedPhaseGaussianNoise -from .marginalized_gaussian_noise import MarginalizedPolarization -from .marginalized_gaussian_noise import MarginalizedHMPolPhase -from .brute_marg import BruteParallelGaussianMarginalize -from .gated_gaussian_noise import (GatedGaussianNoise, GatedGaussianMargPol) -from .single_template import SingleTemplate -from .relbin import Relative -from .hierarchical import HierarchicalModel, MultiSignalModel - - -# Used to manage a model instance across multiple cores or MPI -_global_instance = None - - -def _call_global_model(*args, **kwds): - """Private function for global model (needed for parallelization).""" - return _global_instance(*args, **kwds) # pylint:disable=not-callable - - -def _call_global_model_logprior(*args, **kwds): - """Private function for a calling global's logprior. - - This is needed for samplers that use a separate function for the logprior, - like ``emcee_pt``. - """ - # pylint:disable=not-callable - return _global_instance(*args, callstat='logprior', **kwds) - - -class CallModel(object): - """Wrapper class for calling models from a sampler. - - This class can be called like a function, with the parameter values to - evaluate provided as a list in the same order as the model's - ``variable_params``. In that case, the model is updated with the provided - parameters and then the ``callstat`` retrieved. If ``return_all_stats`` is - set to ``True``, then all of the stats specified by the model's - ``default_stats`` will be returned as a tuple, in addition to the stat - value. 
- - The model's attributes are promoted to this class's namespace, so that any - attribute and method of ``model`` may be called directly from this class. - - This class must be initalized prior to the creation of a ``Pool`` object. - - Parameters - ---------- - model : Model instance - The model to call. - callstat : str - The statistic to call. - return_all_stats : bool, optional - Whether or not to return all of the other statistics along with the - ``callstat`` value. - - Examples - -------- - Create a wrapper around an instance of the ``TestNormal`` model, with the - ``callstat`` set to ``logposterior``: - - >>> from pycbc.inference.models import TestNormal, CallModel - >>> model = TestNormal(['x', 'y']) - >>> call_model = CallModel(model, 'logposterior') - - Now call on a set of parameter values: - - >>> call_model([0.1, -0.2]) - (-1.8628770664093453, (0.0, 0.0, -1.8628770664093453)) - - Note that a tuple of all of the model's ``default_stats`` were returned in - addition to the ``logposterior`` value. We can shut this off by toggling - ``return_all_stats``: - - >>> call_model.return_all_stats = False - >>> call_model([0.1, -0.2]) - -1.8628770664093453 - - Attributes of the model can be called from the call model. For example: - - >>> call_model.variable_params - ('x', 'y') - - """ - - def __init__(self, model, callstat, return_all_stats=True): - self.model = model - self.callstat = callstat - self.return_all_stats = return_all_stats - - def __getattr__(self, attr): - """Adds the models attributes to self.""" - return getattr(self.model, attr) - - def __call__(self, param_values, callstat=None, return_all_stats=None): - """Updates the model with the given parameter values, then calls the - call function. - - Parameters - ---------- - param_values : list of float - The parameter values to test. Assumed to be in the same order as - ``model.sampling_params``. - callstat : str, optional - Specify which statistic to call. Default is to call whatever self's - ``callstat`` is set to. - return_all_stats : bool, optional - Whether or not to return all stats in addition to the ``callstat`` - value. Default is to use self's ``return_all_stats``. - - Returns - ------- - stat : float - The statistic returned by the ``callfunction``. - all_stats : tuple, optional - The values of all of the model's ``default_stats`` at the given - param values. Any stat that has not be calculated is set to - ``numpy.nan``. This is only returned if ``return_all_stats`` is - set to ``True``. - """ - if callstat is None: - callstat = self.callstat - if return_all_stats is None: - return_all_stats = self.return_all_stats - params = dict(zip(self.model.sampling_params, param_values)) - self.model.update(**params) - val = getattr(self.model, callstat) - if return_all_stats: - return val, self.model.get_current_stats() - else: - return val - - -def read_from_config(cp, **kwargs): - """Initializes a model from the given config file. - - The section must have a ``name`` argument. The name argument corresponds to - the name of the class to initialize. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. - \**kwargs : - All other keyword arguments are passed to the ``from_config`` method - of the class specified by the name argument. - - Returns - ------- - cls - The initialized model. 
- """ - # use the name to get the distribution - name = cp.get("model", "name") - return get_model(name).from_config(cp, **kwargs) - - -_models = {_cls.name: _cls for _cls in ( - TestEggbox, - TestNormal, - TestRosenbrock, - TestVolcano, - TestPosterior, - TestPrior, - GaussianNoise, - MarginalizedPhaseGaussianNoise, - MarginalizedPolarization, - MarginalizedHMPolPhase, - BruteParallelGaussianMarginalize, - GatedGaussianNoise, - GatedGaussianMargPol, - SingleTemplate, - Relative, - HierarchicalModel, - MultiSignalModel, -)} - - -class _ModelManager(dict): - """Sub-classes dictionary to manage the collection of available models. - - The first time this is called, any plugin models that are available will be - added to the dictionary before returning. - """ - def __init__(self, *args, **kwargs): - self.retrieve_plugins = True - super().__init__(*args, **kwargs) - - def add_model(self, model): - """Adds a model to the dictionary. - - If the given model has the same name as a model already in the - dictionary, the original model will be overridden. A warning will be - printed in that case. - """ - if super().__contains__(model.name): - logging.warning("Custom model %s will override a model of the " - "same name. If you don't want this, change the " - "model's name attribute and restart.", model.name) - self[model.name] = model - - def add_plugins(self): - """Adds any plugin models that are available. - - This will only add the plugins if ``self.retrieve_plugins = True``. - After this runs, ``self.retrieve_plugins`` is set to ``False``, so that - subsequent calls to this will no re-add models. - """ - if self.retrieve_plugins: - for plugin in _iter_entry_points('pycbc.inference.models'): - self.add_model(plugin.resolve()) - self.retrieve_plugins = False - - def __len__(self): - self.add_plugins() - super().__len__() - - def __contains__(self, key): - self.add_plugins() - return super().__contains__(key) - - def get(self, *args): - self.add_plugins() - return super().get(*args) - - def popitem(self): - self.add_plugins() - return super().popitem() - - def pop(self, *args): - try: - return super().pop(*args) - except KeyError: - self.add_plugins() - return super().pop(*args) - - def keys(self): - self.add_plugins() - return super().keys() - - def values(self): - self.add_plugins() - return super().values() - - def items(self): - self.add_plugins() - return super().items() - - def __iter__(self): - self.add_plugins() - return super().__iter__() - - def __repr__(self): - self.add_plugins() - return super().__repr__() - - def __getitem__(self, item): - try: - return super().__getitem__(item) - except KeyError: - self.add_plugins() - return super().__getitem__(item) - - def __delitem__(self, *args, **kwargs): - try: - super().__delitem__(*args, **kwargs) - except KeyError: - self.add_plugins() - super().__delitem__(*args, **kwargs) - - -models = _ModelManager(_models) - - -def get_models(): - """Returns the dictionary of current models. - - Ensures that plugins are added to the dictionary first. - """ - models.add_plugins() - return models - - -def get_model(model_name): - """Retrieve the given model. - - Parameters - ---------- - model_name : str - The name of the model to get. - - Returns - ------- - model : - The requested model. - """ - return get_models()[model_name] - - -def available_models(): - """List the currently available models.""" - return list(get_models().keys()) - - -def register_model(model): - """Makes a custom model available to PyCBC. 
- - The provided model will be added to the dictionary of models that PyCBC - knows about, using the model's ``name`` attribute. If the ``name`` is the - same as a model that already exists in PyCBC, a warning will be printed. - - Parameters - ---------- - model : pycbc.inference.models.base.BaseModel - The model to use. The model should be a sub-class of - :py:class:`BaseModel ` to ensure - it has the correct API for use within ``pycbc_inference``. - """ - get_models().add_model(model) diff --git a/pycbc/inference/models/.ipynb_checkpoints/base-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/base-checkpoint.py deleted file mode 100644 index e8e0eeb9d22..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/base-checkpoint.py +++ /dev/null @@ -1,886 +0,0 @@ -# Copyright (C) 2016 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# - -"""Base class for models. -""" - -import numpy -import logging -from abc import (ABCMeta, abstractmethod) -from configparser import NoSectionError -from pycbc import (transforms, distributions) -from pycbc.io import FieldArray - - -# -# ============================================================================= -# -# Support classes -# -# ============================================================================= -# - - -class _NoPrior(object): - """Dummy class to just return 0 if no prior is given to a model. - """ - @staticmethod - def apply_boundary_conditions(**params): - return params - - def __call__(self, **params): - return 0. - - -class ModelStats(object): - """Class to hold model's current stat values.""" - - @property - def statnames(self): - """Returns the names of the stats that have been stored.""" - return list(self.__dict__.keys()) - - def getstats(self, names, default=numpy.nan): - """Get the requested stats as a tuple. - - If a requested stat is not an attribute (implying it hasn't been - stored), then the default value is returned for that stat. - - Parameters - ---------- - names : list of str - The names of the stats to get. - default : float, optional - What to return if a requested stat is not an attribute of self. - Default is ``numpy.nan``. - - Returns - ------- - tuple - A tuple of the requested stats. - """ - return tuple(getattr(self, n, default) for n in names) - - def getstatsdict(self, names, default=numpy.nan): - """Get the requested stats as a dictionary. - - If a requested stat is not an attribute (implying it hasn't been - stored), then the default value is returned for that stat. - - Parameters - ---------- - names : list of str - The names of the stats to get. - default : float, optional - What to return if a requested stat is not an attribute of self. 
- Default is ``numpy.nan``. - - Returns - ------- - dict - A dictionary of the requested stats. - """ - return dict(zip(names, self.getstats(names, default=default))) - - -class SamplingTransforms(object): - """Provides methods for transforming between sampling parameter space and - model parameter space. - """ - - def __init__(self, variable_params, sampling_params, - replace_parameters, sampling_transforms): - assert len(replace_parameters) == len(sampling_params), ( - "number of sampling parameters must be the " - "same as the number of replace parameters") - # pull out the replaced parameters - self.sampling_params = [arg for arg in variable_params - if arg not in replace_parameters] - # add the sampling parameters - self.sampling_params += sampling_params - # sort to make sure we have a consistent order - self.sampling_params.sort() - self.sampling_transforms = sampling_transforms - - def logjacobian(self, **params): - r"""Returns the log of the jacobian needed to transform pdfs in the - ``variable_params`` parameter space to the ``sampling_params`` - parameter space. - - Let :math:`\mathbf{x}` be the set of variable parameters, - :math:`\mathbf{y} = f(\mathbf{x})` the set of sampling parameters, and - :math:`p_x(\mathbf{x})` a probability density function defined over - :math:`\mathbf{x}`. - The corresponding pdf in :math:`\mathbf{y}` is then: - - .. math:: - - p_y(\mathbf{y}) = - p_x(\mathbf{x})\left|\mathrm{det}\,\mathbf{J}_{ij}\right|, - - where :math:`\mathbf{J}_{ij}` is the Jacobian of the inverse transform - :math:`\mathbf{x} = g(\mathbf{y})`. This has elements: - - .. math:: - - \mathbf{J}_{ij} = \frac{\partial g_i}{\partial{y_j}} - - This function returns - :math:`\log \left|\mathrm{det}\,\mathbf{J}_{ij}\right|`. - - - Parameters - ---------- - \**params : - The keyword arguments should specify values for all of the variable - args and all of the sampling args. - - Returns - ------- - float : - The value of the jacobian. - """ - return numpy.log(abs(transforms.compute_jacobian( - params, self.sampling_transforms, inverse=True))) - - def apply(self, samples, inverse=False): - """Applies the sampling transforms to the given samples. - - Parameters - ---------- - samples : dict or FieldArray - The samples to apply the transforms to. - inverse : bool, optional - Whether to apply the inverse transforms (i.e., go from the sampling - args to the ``variable_params``). Default is False. - - Returns - ------- - dict or FieldArray - The transformed samples, along with the original samples. - """ - return transforms.apply_transforms(samples, self.sampling_transforms, - inverse=inverse) - - @classmethod - def from_config(cls, cp, variable_params): - """Gets sampling transforms specified in a config file. - - Sampling parameters and the parameters they replace are read from the - ``sampling_params`` section, if it exists. Sampling transforms are - read from the ``sampling_transforms`` section(s), using - ``transforms.read_transforms_from_config``. - - An ``AssertionError`` is raised if no ``sampling_params`` section - exists in the config file. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. - variable_params : list - List of parameter names of the original variable params. - - Returns - ------- - SamplingTransforms - A sampling transforms class. 
- """ - # Check if a sampling_params section is provided - try: - sampling_params, replace_parameters = \ - read_sampling_params_from_config(cp) - except NoSectionError as e: - logging.warning("No sampling_params section read from config file") - raise e - # get sampling transformations - sampling_transforms = transforms.read_transforms_from_config( - cp, 'sampling_transforms') - logging.info("Sampling in {} in place of {}".format( - ', '.join(sampling_params), ', '.join(replace_parameters))) - return cls(variable_params, sampling_params, - replace_parameters, sampling_transforms) - - -def read_sampling_params_from_config(cp, section_group=None, - section='sampling_params'): - """Reads sampling parameters from the given config file. - - Parameters are read from the `[({section_group}_){section}]` section. - The options should list the variable args to transform; the parameters they - point to should list the parameters they are to be transformed to for - sampling. If a multiple parameters are transformed together, they should - be comma separated. Example: - - .. code-block:: ini - - [sampling_params] - mass1, mass2 = mchirp, logitq - spin1_a = logitspin1_a - - Note that only the final sampling parameters should be listed, even if - multiple intermediate transforms are needed. (In the above example, a - transform is needed to go from mass1, mass2 to mchirp, q, then another one - needed to go from q to logitq.) These transforms should be specified - in separate sections; see ``transforms.read_transforms_from_config`` for - details. - - Parameters - ---------- - cp : WorkflowConfigParser - An open config parser to read from. - section_group : str, optional - Append `{section_group}_` to the section name. Default is None. - section : str, optional - The name of the section. Default is 'sampling_params'. - - Returns - ------- - sampling_params : list - The list of sampling parameters to use instead. - replaced_params : list - The list of variable args to replace in the sampler. - """ - if section_group is not None: - section_prefix = '{}_'.format(section_group) - else: - section_prefix = '' - section = section_prefix + section - replaced_params = set() - sampling_params = set() - for args in cp.options(section): - map_args = cp.get(section, args) - sampling_params.update(set(map(str.strip, map_args.split(',')))) - replaced_params.update(set(map(str.strip, args.split(',')))) - return sorted(sampling_params), sorted(replaced_params) - - -# -# ============================================================================= -# -# Base model definition -# -# ============================================================================= -# - - -class BaseModel(metaclass=ABCMeta): - r"""Base class for all models. - - Given some model :math:`h` with parameters :math:`\Theta`, Bayes Theorem - states that the probability of observing parameter values :math:`\vartheta` - is: - - .. math:: - - p(\vartheta|h) = \frac{p(h|\vartheta) p(\vartheta)}{p(h)}. - - Here: - - * :math:`p(\vartheta|h)` is the **posterior** probability; - - * :math:`p(h|\vartheta)` is the **likelihood**; - - * :math:`p(\vartheta)` is the **prior**; - - * :math:`p(h)` is the **evidence**. - - This class defines properties and methods for evaluating the log - likelihood, log prior, and log posteror. A set of parameter values is set - using the ``update`` method. Calling the class's - ``log(likelihood|prior|posterior)`` properties will then evaluate the model - at those parameter values. 
- - Classes that inherit from this class must implement a ``_loglikelihood`` - function that can be called by ``loglikelihood``. - - Parameters - ---------- - variable_params : (tuple of) string(s) - A tuple of parameter names that will be varied. - static_params : dict, optional - A dictionary of parameter names -> values to keep fixed. - prior : callable, optional - A callable class or function that computes the log of the prior. If - None provided, will use ``_noprior``, which returns 0 for all parameter - values. - sampling_params : list, optional - Replace one or more of the ``variable_params`` with the given - parameters for sampling. - replace_parameters : list, optional - The ``variable_params`` to replace with sampling parameters. Must be - the same length as ``sampling_params``. - sampling_transforms : list, optional - List of transforms to use to go between the ``variable_params`` and the - sampling parameters. Required if ``sampling_params`` is not None. - waveform_transforms : list, optional - A list of transforms to convert the ``variable_params`` into something - understood by the likelihood model. This is useful if the prior is - more easily parameterized in parameters that are different than what - the likelihood is most easily defined in. Since these are used solely - for converting parameters, and not for rescaling the parameter space, - a Jacobian is not required for these transforms. - """ - name = None - - def __init__(self, variable_params, static_params=None, prior=None, - sampling_transforms=None, waveform_transforms=None): - # store variable and static args - self.variable_params = variable_params - self.static_params = static_params - # store prior - if prior is None: - self.prior_distribution = _NoPrior() - elif set(prior.variable_args) != set(variable_params): - raise ValueError("variable params of prior and model must be the " - "same") - else: - self.prior_distribution = prior - # store transforms - self.sampling_transforms = sampling_transforms - self.waveform_transforms = waveform_transforms - # initialize current params to None - self._current_params = None - # initialize a model stats - self._current_stats = ModelStats() - - @property - def variable_params(self): - """Returns the model parameters.""" - return self._variable_params - - @variable_params.setter - def variable_params(self, variable_params): - if isinstance(variable_params, str): - variable_params = (variable_params,) - if not isinstance(variable_params, tuple): - variable_params = tuple(variable_params) - self._variable_params = variable_params - - @property - def static_params(self): - """Returns the model's static arguments.""" - return self._static_params - - @static_params.setter - def static_params(self, static_params): - if static_params is None: - static_params = {} - self._static_params = static_params - - @property - def sampling_params(self): - """Returns the sampling parameters. - - If ``sampling_transforms`` is None, this is the same as the - ``variable_params``. - """ - if self.sampling_transforms is None: - sampling_params = self.variable_params - else: - sampling_params = self.sampling_transforms.sampling_params - return sampling_params - - def update(self, **params): - """Updates the current parameter positions and resets stats. - - If any sampling transforms are specified, they are applied to the - params before being stored. 
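# Toy, self-contained sketch (not the PyCBC classes) of the update-then-evaluate
# workflow described in the BaseModel docstring above: parameters are stored once
# with update(), and the log prior/likelihood/posterior are then evaluated at that
# stored point, with the likelihood skipped when the prior is -inf.
import numpy

class ToyModel:
    def __init__(self, variable_params):
        self.variable_params = tuple(variable_params)
        self._current_params = None

    def update(self, **params):
        self._current_params = params

    @property
    def logprior(self):
        # flat prior on [0, 10) for every parameter
        p = self._current_params
        inbounds = all(0. <= p[name] < 10. for name in self.variable_params)
        return 0. if inbounds else -numpy.inf

    @property
    def loglikelihood(self):
        # unit-variance Gaussian centred on 5 for each parameter
        p = self._current_params
        return -0.5 * sum((p[name] - 5.) ** 2 for name in self.variable_params)

    @property
    def logposterior(self):
        logp = self.logprior
        return logp if logp == -numpy.inf else logp + self.loglikelihood

model = ToyModel(['x'])
model.update(x=4.5)
print(model.logposterior)   # -0.125
model.update(x=-1.)
print(model.logposterior)   # -inf; likelihood never evaluated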
- """ - # add the static params - params.update(self.static_params) - self._current_params = self._transform_params(**params) - self._current_stats = ModelStats() - - @property - def current_params(self): - if self._current_params is None: - raise ValueError("no parameters values currently stored; " - "run update to add some") - return self._current_params - - @property - def default_stats(self): - """The stats that ``get_current_stats`` returns by default.""" - return ['logjacobian', 'logprior', 'loglikelihood'] + self._extra_stats - - @property - def _extra_stats(self): - """Allows child classes to add more stats to the default stats. - - This returns an empty list; classes that inherit should override this - property if they want to add extra stats. - """ - return [] - - def get_current_stats(self, names=None): - """Return one or more of the current stats as a tuple. - - This function does no computation. It only returns what has already - been calculated. If a stat hasn't been calculated, it will be returned - as ``numpy.nan``. - - Parameters - ---------- - names : list of str, optional - Specify the names of the stats to retrieve. If ``None`` (the - default), will return ``default_stats``. - - Returns - ------- - tuple : - The current values of the requested stats, as a tuple. The order - of the stats is the same as the names. - """ - if names is None: - names = self.default_stats - return self._current_stats.getstats(names) - - @property - def current_stats(self): - """Return the ``default_stats`` as a dict. - - This does no computation. It only returns what has already been - calculated. If a stat hasn't been calculated, it will be returned - as ``numpy.nan``. - - Returns - ------- - dict : - Dictionary of stat names -> current stat values. - """ - return self._current_stats.getstatsdict(self.default_stats) - - def _trytoget(self, statname, fallback, apply_transforms=False, **kwargs): - r"""Helper function to get a stat from ``_current_stats``. - - If the statistic hasn't been calculated, ``_current_stats`` will raise - an ``AttributeError``. In that case, the ``fallback`` function will - be called. If that call is successful, the ``statname`` will be added - to ``_current_stats`` with the returned value. - - Parameters - ---------- - statname : str - The stat to get from ``current_stats``. - fallback : method of self - The function to call if the property call fails. - apply_transforms : bool, optional - Apply waveform transforms to the current parameters before calling - the fallback function. Default is False. - \**kwargs : - Any other keyword arguments are passed through to the function. - - Returns - ------- - float : - The value of the property. - """ - try: - return getattr(self._current_stats, statname) - except AttributeError: - # apply waveform transforms if requested - if apply_transforms and self.waveform_transforms is not None: - self._current_params = transforms.apply_transforms( - self._current_params, self.waveform_transforms, - inverse=False) - val = fallback(**kwargs) - setattr(self._current_stats, statname, val) - return val - - @property - def loglikelihood(self): - """The log likelihood at the current parameters. - - This will initially try to return the ``current_stats.loglikelihood``. - If that raises an ``AttributeError``, will call `_loglikelihood`` to - calculate it and store it to ``current_stats``. 
- """ - return self._trytoget('loglikelihood', self._loglikelihood, - apply_transforms=True) - - @abstractmethod - def _loglikelihood(self): - """Low-level function that calculates the log likelihood of the current - params.""" - pass - - @property - def logjacobian(self): - """The log jacobian of the sampling transforms at the current postion. - - If no sampling transforms were provided, will just return 0. - - Parameters - ---------- - \**params : - The keyword arguments should specify values for all of the variable - args and all of the sampling args. - - Returns - ------- - float : - The value of the jacobian. - """ - return self._trytoget('logjacobian', self._logjacobian) - - def _logjacobian(self): - """Calculates the logjacobian of the current parameters.""" - if self.sampling_transforms is None: - logj = 0. - else: - logj = self.sampling_transforms.logjacobian( - **self.current_params) - return logj - - @property - def logprior(self): - """Returns the log prior at the current parameters.""" - return self._trytoget('logprior', self._logprior) - - def _logprior(self): - """Calculates the log prior at the current parameters.""" - logj = self.logjacobian - logp = self.prior_distribution(**self.current_params) + logj - if numpy.isnan(logp): - logp = -numpy.inf - return logp - - @property - def logposterior(self): - """Returns the log of the posterior of the current parameter values. - - The logprior is calculated first. If the logprior returns ``-inf`` - (possibly indicating a non-physical point), then the ``loglikelihood`` - is not called. - """ - logp = self.logprior - if logp == -numpy.inf: - return logp - else: - return logp + self.loglikelihood - - def prior_rvs(self, size=1, prior=None): - """Returns random variates drawn from the prior. - - If the ``sampling_params`` are different from the ``variable_params``, - the variates are transformed to the `sampling_params` parameter space - before being returned. - - Parameters - ---------- - size : int, optional - Number of random values to return for each parameter. Default is 1. - prior : JointDistribution, optional - Use the given prior to draw values rather than the saved prior. - - Returns - ------- - FieldArray - A field array of the random values. - """ - # draw values from the prior - if prior is None: - prior = self.prior_distribution - p0 = prior.rvs(size=size) - # transform if necessary - if self.sampling_transforms is not None: - ptrans = self.sampling_transforms.apply(p0) - # pull out the sampling args - p0 = FieldArray.from_arrays([ptrans[arg] - for arg in self.sampling_params], - names=self.sampling_params) - return p0 - - def _transform_params(self, **params): - """Applies sampling transforms and boundary conditions to parameters. - - Parameters - ---------- - \**params : - Key, value pairs of parameters to apply the transforms to. - - Returns - ------- - dict - A dictionary of the transformed parameters. - """ - # apply inverse transforms to go from sampling parameters to - # variable args - if self.sampling_transforms is not None: - params = self.sampling_transforms.apply(params, inverse=True) - # apply boundary conditions - params = self.prior_distribution.apply_boundary_conditions(**params) - return params - - # - # Methods for initiating from a config file. - # - @staticmethod - def extra_args_from_config(cp, section, skip_args=None, dtypes=None): - """Gets any additional keyword in the given config file. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. 
- section : str - The name of the section to read. - skip_args : list of str, optional - Names of arguments to skip. - dtypes : dict, optional - A dictionary of arguments -> data types. If an argument is found - in the dict, it will be cast to the given datatype. Otherwise, the - argument's value will just be read from the config file (and thus - be a string). - - Returns - ------- - dict - Dictionary of keyword arguments read from the config file. - """ - kwargs = {} - if dtypes is None: - dtypes = {} - if skip_args is None: - skip_args = [] - read_args = [opt for opt in cp.options(section) - if opt not in skip_args] - for opt in read_args: - val = cp.get(section, opt) - # try to cast the value if a datatype was specified for this opt - try: - val = dtypes[opt](val) - except KeyError: - pass - kwargs[opt] = val - return kwargs - - @staticmethod - def prior_from_config(cp, variable_params, static_params, prior_section, - constraint_section): - """Gets arguments and keyword arguments from a config file. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. - variable_params : list - List of variable model parameter names. - static_params : dict - Dictionary of static model parameters and their values. - prior_section : str - Section to read prior(s) from. - constraint_section : str - Section to read constraint(s) from. - - Returns - ------- - pycbc.distributions.JointDistribution - The prior. - """ - # get prior distribution for each variable parameter - logging.info("Setting up priors for each parameter") - dists = distributions.read_distributions_from_config(cp, prior_section) - constraints = distributions.read_constraints_from_config( - cp, constraint_section, static_args=static_params) - return distributions.JointDistribution(variable_params, *dists, - constraints=constraints) - - @classmethod - def _init_args_from_config(cls, cp): - """Helper function for loading parameters. - - This retrieves the prior, variable parameters, static parameterss, - constraints, sampling transforms, and waveform transforms - (if provided). - - Parameters - ---------- - cp : ConfigParser - Config parser to read. - - Returns - ------- - dict : - Dictionary of the arguments. Has keys ``variable_params``, - ``static_params``, ``prior``, and ``sampling_transforms``. If - waveform transforms are in the config file, will also have - ``waveform_transforms``. 
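# Sketch of the extra keyword-argument reading described above, with a plain
# configparser.ConfigParser standing in for WorkflowConfigParser and a made-up
# option name; options not listed in dtypes stay as strings.
from configparser import ConfigParser

cp = ConfigParser()
cp.read_string("""
[model]
name = some_model
low-frequency-cutoff = 20
""")

def extra_args(cp, section, skip_args=(), dtypes=None):
    dtypes = dtypes or {}
    kwargs = {}
    for opt in cp.options(section):
        if opt in skip_args:
            continue
        val = cp.get(section, opt)
        try:
            val = dtypes[opt](val)
        except KeyError:
            pass
        kwargs[opt] = val
    return kwargs

print(extra_args(cp, 'model', skip_args=('name',),
                 dtypes={'low-frequency-cutoff': float}))
# {'low-frequency-cutoff': 20.0}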
- """ - section = "model" - prior_section = "prior" - vparams_section = 'variable_params' - sparams_section = 'static_params' - constraint_section = 'constraint' - # check that the name exists and matches - name = cp.get(section, 'name') - if name != cls.name: - raise ValueError("section's {} name does not match mine {}".format( - name, cls.name)) - # get model parameters - variable_params, static_params = distributions.read_params_from_config( - cp, prior_section=prior_section, vargs_section=vparams_section, - sargs_section=sparams_section) - # get prior - prior = cls.prior_from_config( - cp, variable_params, static_params, prior_section, - constraint_section) - args = {'variable_params': variable_params, - 'static_params': static_params, - 'prior': prior} - # try to load sampling transforms - try: - sampling_transforms = SamplingTransforms.from_config( - cp, variable_params) - except NoSectionError: - sampling_transforms = None - args['sampling_transforms'] = sampling_transforms - # get any waveform transforms - if any(cp.get_subsections('waveform_transforms')): - logging.info("Loading waveform transforms") - waveform_transforms = transforms.read_transforms_from_config( - cp, 'waveform_transforms') - args['waveform_transforms'] = waveform_transforms - else: - waveform_transforms = [] - # safety check for spins - # we won't do this if the following exists in the config file - ignore = "no_err_on_missing_cartesian_spins" - check_for_cartesian_spins(1, variable_params, static_params, - waveform_transforms, cp, ignore) - check_for_cartesian_spins(2, variable_params, static_params, - waveform_transforms, cp, ignore) - return args - - @classmethod - def from_config(cls, cp, **kwargs): - """Initializes an instance of this class from the given config file. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. - \**kwargs : - All additional keyword arguments are passed to the class. Any - provided keyword will over ride what is in the config file. - """ - args = cls._init_args_from_config(cp) - # get any other keyword arguments provided in the model section - args.update(cls.extra_args_from_config(cp, "model", - skip_args=['name'])) - args.update(kwargs) - return cls(**args) - - def write_metadata(self, fp, group=None): - """Writes metadata to the given file handler. - - Parameters - ---------- - fp : pycbc.inference.io.BaseInferenceFile instance - The inference file to write to. - group : str, optional - If provided, the metadata will be written to the attrs specified - by group, i.e., to ``fp[group].attrs``. Otherwise, metadata is - written to the top-level attrs (``fp.attrs``). - """ - attrs = fp.getattrs(group=group) - attrs['model'] = self.name - attrs['variable_params'] = list(map(str, self.variable_params)) - attrs['sampling_params'] = list(map(str, self.sampling_params)) - fp.write_kwargs_to_attrs(attrs, static_params=self.static_params) - - -def check_for_cartesian_spins(which, variable_params, static_params, - waveform_transforms, cp, ignore): - """Checks that if any spin parameters exist, cartesian spins also exist. - - This looks for parameters starting with ``spinN`` in the variable and - static params, where ``N`` is either 1 or 2 (specified by the ``which`` - argument). If any parameters are found with those names, the params and - the output of the waveform transforms are checked to see that there is - at least one of ``spinN(x|y|z)``. If not, a ``ValueError`` is raised. 
- - This check will not be done if the config file has an section given by - the ignore argument. - - Parameters - ---------- - which : {1, 2} - Which component to check for. Must be either 1 or 2. - variable_params : list - List of the variable parameters. - static_params : dict - The dictionary of static params. - waveform_transforms : list - List of the transforms that will be applied to the variable and - static params before being passed to the waveform generator. - cp : ConfigParser - The config file. - ignore : str - The section to check for in the config file. If the section is - present in the config file, the check will not be done. - """ - # don't do this check if the config file has the ignore section - if cp.has_section(ignore): - logging.info("[{}] found in config file; not performing check for " - "cartesian spin{} parameters".format(ignore, which)) - return - errmsg = ( - "Spin parameters {sp} found in variable/static " - "params for component {n}, but no Cartesian spin parameters ({cp}) " - "found in either the variable/static params or " - "the waveform transform outputs. Most waveform " - "generators only recognize Cartesian spin " - "parameters; without them, all spins are set to " - "zero. If you are using spherical spin coordinates, add " - "the following waveform_transform to your config file:\n\n" - "[waveform_transforms-spin{n}x+spin{n}y+spin{n}z]\n" - "name = spherical_to_cartesian\n" - "x = spin{n}x\n" - "y = spin{n}y\n" - "z = spin{n}z\n" - "radial = spin{n}_a\n" - "azimuthal = spin{n}_azimuthal\n" - "polar = spin{n}_polar\n\n" - "Here, spin{n}_a, spin{n}_azimuthal, and spin{n}_polar are the names " - "of your radial, azimuthal, and polar coordinates, respectively. " - "If you intentionally did not include Cartesian spin parameters, " - "(e.g., you are using a custom waveform or model) add\n\n" - "[{ignore}]\n\n" - "to your config file as an empty section and rerun. This check will " - "not be performed in that case.") - allparams = set(variable_params) | set(static_params.keys()) - spinparams = set(p for p in allparams - if p.startswith('spin{}'.format(which))) - if any(spinparams): - cartspins = set('spin{}{}'.format(which, coord) - for coord in ['x', 'y', 'z']) - # add any parameters to all params that will be output by waveform - # transforms - allparams = allparams.union(*[t.outputs for t in waveform_transforms]) - if not any(allparams & cartspins): - raise ValueError(errmsg.format(sp=', '.join(spinparams), - cp=', '.join(cartspins), - n=which, ignore=ignore)) diff --git a/pycbc/inference/models/.ipynb_checkpoints/base_data-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/base_data-checkpoint.py deleted file mode 100644 index a8b0de84632..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/base_data-checkpoint.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright (C) 2018 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. 
-# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# - -"""Base classes for mofdels with data. -""" - -import numpy -from abc import (ABCMeta, abstractmethod) -from .base import BaseModel - - -class BaseDataModel(BaseModel, metaclass=ABCMeta): - r"""Base class for models that require data and a waveform generator. - - This adds propeties for the log of the likelihood that the data contain - noise, ``lognl``, and the log likelihood ratio ``loglr``. - - Classes that inherit from this class must define ``_loglr`` and ``_lognl`` - functions, in addition to the ``_loglikelihood`` requirement inherited from - ``BaseModel``. - - Parameters - ---------- - variable_params : (tuple of) string(s) - A tuple of parameter names that will be varied. - data : dict - A dictionary of data, in which the keys are the detector names and the - values are the data. - recalibration : dict of pycbc.calibration.Recalibrate, optional - Dictionary of detectors -> recalibration class instances for - recalibrating data. - gates : dict of tuples, optional - Dictionary of detectors -> tuples of specifying gate times. The - sort of thing returned by `pycbc.gate.gates_from_cli`. - injection_file : str, optional - If an injection was added to the data, the name of the injection file - used. If provided, the injection parameters will be written to - file when ``write_metadata`` is called. - - \**kwargs : - All other keyword arguments are passed to ``BaseModel``. - - - See ``BaseModel`` for additional attributes and properties. - """ - - def __init__(self, variable_params, data, recalibration=None, gates=None, - injection_file=None, no_save_data=False, **kwargs): - self._data = None - self.data = data - self.recalibration = recalibration - self.no_save_data = no_save_data - self.gates = gates - self.injection_file = injection_file - super(BaseDataModel, self).__init__(variable_params, **kwargs) - - @property - def data(self): - """dict: Dictionary mapping detector names to data.""" - return self._data - - @data.setter - def data(self, data): - """Store a copy of the data.""" - self._data = {det: d.copy() for (det, d) in data.items()} - - @property - def _extra_stats(self): - """Adds ``loglr`` and ``lognl`` to the ``default_stats``.""" - return ['loglr', 'lognl'] - - @property - def lognl(self): - """The log likelihood of the model assuming the data is noise. - - This will initially try to return the ``current_stats.lognl``. - If that raises an ``AttributeError``, will call `_lognl`` to - calculate it and store it to ``current_stats``. - """ - return self._trytoget('lognl', self._lognl) - - @abstractmethod - def _lognl(self): - """Low-level function that calculates the lognl.""" - pass - - @property - def loglr(self): - """The log likelihood ratio at the current parameters. - - This will initially try to return the ``current_stats.loglr``. - If that raises an ``AttributeError``, will call `_loglr`` to - calculate it and store it to ``current_stats``. 
- """ - return self._trytoget('loglr', self._loglr, apply_transforms=True) - - @abstractmethod - def _loglr(self): - """Low-level function that calculates the loglr.""" - pass - - @property - def logplr(self): - """Returns the log of the prior-weighted likelihood ratio at the - current parameter values. - - The logprior is calculated first. If the logprior returns ``-inf`` - (possibly indicating a non-physical point), then ``loglr`` is not - called. - """ - logp = self.logprior - if logp == -numpy.inf: - return logp - else: - return logp + self.loglr - - @property - def detectors(self): - """list: Returns the detectors used.""" - return list(self._data.keys()) - - def write_metadata(self, fp, group=None): - """Adds data to the metadata that's written. - - Parameters - ---------- - fp : pycbc.inference.io.BaseInferenceFile instance - The inference file to write to. - group : str, optional - If provided, the metadata will be written to the attrs specified - by group, i.e., to ``fp[group].attrs``. Otherwise, metadata is - written to the top-level attrs (``fp.attrs``). - """ - super().write_metadata(fp, group=group) - if not self.no_save_data: - fp.write_stilde(self.data, group=group) - # save injection parameters - if self.injection_file is not None: - fp.write_injections(self.injection_file, group=group) diff --git a/pycbc/inference/models/.ipynb_checkpoints/data_utils-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/data_utils-checkpoint.py deleted file mode 100644 index c51982ae69a..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/data_utils-checkpoint.py +++ /dev/null @@ -1,554 +0,0 @@ -# Copyright (C) 2018 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - -"""Utilities for loading data for models. -""" - -import logging -from argparse import ArgumentParser -from time import sleep -import numpy -try: - from mpi4py import MPI -except ImportError: - MPI = None - -from pycbc.types import MultiDetOptionAction -from pycbc.psd import (insert_psd_option_group_multi_ifo, - from_cli_multi_ifos as psd_from_cli_multi_ifos, - verify_psd_options_multi_ifo) -from pycbc import strain -from pycbc.strain import (gates_from_cli, psd_gates_from_cli, - apply_gates_to_td, apply_gates_to_fd, - verify_strain_options_multi_ifo) -from pycbc import dq - - -def strain_from_cli_multi_ifos(*args, **kwargs): - """Wrapper around strain.from_cli_multi_ifos that tries a few times before - quiting. - - When running in a parallel environment, multiple concurrent queries to the - segment data base can cause time out errors. If that happens, this will - sleep for a few seconds, then try again a few times before giving up. 
- """ - count = 0 - while count < 3: - try: - return strain.from_cli_multi_ifos(*args, **kwargs) - except RuntimeError as e: - exception = e - count += 1 - sleep(10) - # if get to here, we've tries 3 times and still got an error, so exit - raise exception - - -# -# ============================================================================= -# -# Utilities for gravitational-wave data -# -# ============================================================================= -# -class NoValidDataError(Exception): - """This should be raised if a continous segment of valid data could not be - found. - """ - pass - - -def create_data_parser(): - """Creates an argument parser for loading GW data.""" - parser = ArgumentParser() - # add data options - parser.add_argument("--instruments", type=str, nargs="+", required=True, - help="Instruments to analyze, eg. H1 L1.") - parser.add_argument("--trigger-time", type=float, default=0., - help="Reference GPS time (at geocenter) from which " - "the (anlaysis|psd)-(start|end)-time options are " - "measured. The integer seconds will be used. " - "Default is 0; i.e., if not provided, " - "the analysis and psd times should be in GPS " - "seconds.") - parser.add_argument("--analysis-start-time", type=int, required=True, - nargs='+', action=MultiDetOptionAction, - metavar='IFO:TIME', - help="The start time to use for the analysis, " - "measured with respect to the trigger-time. " - "If psd-inverse-length is provided, the given " - "start time will be padded by half that length " - "to account for wrap-around effects.") - parser.add_argument("--analysis-end-time", type=int, required=True, - nargs='+', action=MultiDetOptionAction, - metavar='IFO:TIME', - help="The end time to use for the analysis, " - "measured with respect to the trigger-time. " - "If psd-inverse-length is provided, the given " - "end time will be padded by half that length " - "to account for wrap-around effects.") - parser.add_argument("--psd-start-time", type=int, default=None, - nargs='+', action=MultiDetOptionAction, - metavar='IFO:TIME', - help="Start time to use for PSD estimation, measured " - "with respect to the trigger-time.") - parser.add_argument("--psd-end-time", type=int, default=None, - nargs='+', action=MultiDetOptionAction, - metavar='IFO:TIME', - help="End time to use for PSD estimation, measured " - "with respect to the trigger-time.") - parser.add_argument("--data-conditioning-low-freq", type=float, - nargs="+", action=MultiDetOptionAction, - metavar='IFO:FLOW', dest="low_frequency_cutoff", - help="Low frequency cutoff of the data. Needed for " - "PSD estimation and when creating fake strain. " - "If not provided, will use the model's " - "low-frequency-cutoff.") - insert_psd_option_group_multi_ifo(parser) - strain.insert_strain_option_group_multi_ifo(parser, gps_times=False) - strain.add_gate_option_group(parser) - # add arguments for dq - dqgroup = parser.add_argument_group("Options for quering data quality " - "(DQ)") - dqgroup.add_argument('--dq-segment-name', default='DATA', - help='The status flag to query for data quality. ' - 'Default is "DATA".') - dqgroup.add_argument('--dq-source', choices=['any', 'GWOSC', 'dqsegdb'], - default='any', - help='Where to look for DQ information. 
If "any" ' - '(the default) will first try GWOSC, then ' - 'dqsegdb.') - dqgroup.add_argument('--dq-server', default='https://segments.ligo.org', - help='The server to use for dqsegdb.') - dqgroup.add_argument('--veto-definer', default=None, - help='Path to a veto definer file that defines ' - 'groups of flags, which themselves define a set ' - 'of DQ segments.') - return parser - - -def check_validtimes(detector, gps_start, gps_end, shift_to_valid=False, - max_shift=None, segment_name='DATA', - **kwargs): - r"""Checks DQ server to see if the given times are in a valid segment. - - If the ``shift_to_valid`` flag is provided, the times will be shifted left - or right to try to find a continous valid block nearby. The shifting starts - by shifting the times left by 1 second. If that does not work, it shifts - the times right by one second. This continues, increasing the shift time by - 1 second, until a valid block could be found, or until the shift size - exceeds ``max_shift``. - - If the given times are not in a continuous valid segment, or a valid block - cannot be found nearby, a ``NoValidDataError`` is raised. - - Parameters - ---------- - detector : str - The name of the detector to query; e.g., 'H1'. - gps_start : int - The GPS start time of the segment to query. - gps_end : int - The GPS end time of the segment to query. - shift_to_valid : bool, optional - If True, will try to shift the gps start and end times to the nearest - continous valid segment of data. Default is False. - max_shift : int, optional - The maximum number of seconds to try to shift left or right to find - a valid segment. Default is ``gps_end - gps_start``. - segment_name : str, optional - The status flag to query; passed to :py:func:`pycbc.dq.query_flag`. - Default is "DATA". - \**kwargs : - All other keyword arguments are passed to - :py:func:`pycbc.dq.query_flag`. - - Returns - ------- - use_start : int - The start time to use. If ``shift_to_valid`` is True, this may differ - from the given GPS start time. - use_end : int - The end time to use. If ``shift_to_valid`` is True, this may differ - from the given GPS end time. - """ - # expand the times checked encase we need to shift - if max_shift is None: - max_shift = int(gps_end - gps_start) - check_start = gps_start - max_shift - check_end = gps_end + max_shift - # if we're running in an mpi enviornment and we're not the parent process, - # we'll wait before quering the segment database. 
This will result in - # getting the segments from the cache, so as not to overload the database - if MPI is not None and (MPI.COMM_WORLD.Get_size() > 1 and - MPI.COMM_WORLD.Get_rank() != 0): - # we'll wait for 2 minutes - sleep(120) - validsegs = dq.query_flag(detector, segment_name, check_start, - check_end, cache=True, - **kwargs) - use_start = gps_start - use_end = gps_end - # shift if necessary - if shift_to_valid: - shiftsize = 1 - while (use_start, use_end) not in validsegs and shiftsize < max_shift: - # try shifting left - use_start = gps_start - shiftsize - use_end = gps_end - shiftsize - if (use_start, use_end) not in validsegs: - # try shifting right - use_start = gps_start + shiftsize - use_end = gps_end + shiftsize - shiftsize += 1 - # check that we have a valid range - if (use_start, use_end) not in validsegs: - raise NoValidDataError("Could not find a continous valid segment in " - "in detector {}".format(detector)) - return use_start, use_end - - -def detectors_with_valid_data(detectors, gps_start_times, gps_end_times, - pad_data=None, err_on_missing_detectors=False, - **kwargs): - r"""Determines which detectors have valid data. - - Parameters - ---------- - detectors : list of str - Names of the detector names to check. - gps_start_times : dict - Dictionary of detector name -> start time listing the GPS start times - of the segment to check for each detector. - gps_end_times : dict - Dictionary of detector name -> end time listing the GPS end times of - the segment to check for each detector. - pad_data : dict, optional - Dictionary of detector name -> pad time to add to the beginning/end of - the GPS start/end times before checking. A pad time for every detector - in ``detectors`` must be given. Default (None) is to apply no pad to - the times. - err_on_missing_detectors : bool, optional - If True, a ``NoValidDataError`` will be raised if any detector does not - have continous valid data in its requested segment. Otherwise, the - detector will not be included in the returned list of detectors with - valid data. Default is False. - \**kwargs : - All other keyword arguments are passed to ``check_validtimes``. - - Returns - ------- - dict : - A dictionary of detector name -> valid times giving the detectors with - valid data and their segments. If ``shift_to_valid`` was passed to - ``check_validtimes`` this may not be the same as the input segments. If - no valid times could be found for a detector (and - ``err_on_missing_detectors`` is False), it will not be included in the - returned dictionary. - """ - if pad_data is None: - pad_data = {det: 0 for det in detectors} - dets_with_data = {} - for det in detectors: - logging.info("Checking that %s has valid data in the requested " - "segment", det) - try: - pad = pad_data[det] - start, end = check_validtimes(det, gps_start_times[det]-pad, - gps_end_times[det]+pad, - **kwargs) - dets_with_data[det] = (start+pad, end-pad) - except NoValidDataError as e: - if err_on_missing_detectors: - raise e - logging.warning("WARNING: Detector %s will not be used in " - "the analysis, as it does not have " - "continuous valid data that spans the " - "segment [%d, %d).", det, gps_start_times[det]-pad, - gps_end_times[det]+pad) - return dets_with_data - - -def check_for_nans(strain_dict): - """Checks if any data in a dictionary of strains has NaNs. - - If any NaNs are found, a ``ValueError`` is raised. - - Parameters - ---------- - strain_dict : dict - Dictionary of detectors -> - :py:class:`pycbc.types.timeseries.TimeSeries`. 
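# Simplified sketch of the shift-to-valid search described in check_validtimes
# above: try the requested segment, then shift left and right by 1, 2, ... seconds
# until a valid segment is found or max_shift is reached.  Valid segments are
# represented here as an exact set of (start, end) pairs purely for illustration.
def shift_to_valid(start, end, validsegs, max_shift):
    if (start, end) in validsegs:
        return start, end
    for shift in range(1, max_shift):
        if (start - shift, end - shift) in validsegs:
            return start - shift, end - shift
        if (start + shift, end + shift) in validsegs:
            return start + shift, end + shift
    raise ValueError("could not find a continuous valid segment")

valid = {(999999998, 1000000002)}
print(shift_to_valid(1000000000, 1000000004, valid, max_shift=4))
# (999999998, 1000000002)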
- """ - for det, ts in strain_dict.items(): - if numpy.isnan(ts.numpy()).any(): - raise ValueError("NaN found in strain from {}".format(det)) - - -def data_opts_from_config(cp, section, filter_flow): - """Loads data options from a section in a config file. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file to read. - section : str - The section to read. All options in the section will be loaded as-if - they wre command-line arguments. - filter_flow : dict - Dictionary of detectors -> inner product low frequency cutoffs. - If a `data-conditioning-low-freq` cutoff wasn't provided for any - of the detectors, these values will be used. Otherwise, the - data-conditioning-low-freq must be less than the inner product cutoffs. - If any are not, a ``ValueError`` is raised. - - Returns - ------- - opts : parsed argparse.ArgumentParser - An argument parser namespace that was constructed as if the options - were specified on the command line. - """ - # convert the section options into a command-line options - optstr = cp.section_to_cli(section) - # create a fake parser to parse them - parser = create_data_parser() - # parse the options - opts = parser.parse_args(optstr.split()) - # figure out the times to use - logging.info("Determining analysis times to use") - opts.trigger_time = int(opts.trigger_time) - gps_start = opts.analysis_start_time.copy() - gps_end = opts.analysis_end_time.copy() - for det in opts.instruments: - gps_start[det] += opts.trigger_time - gps_end[det] += opts.trigger_time - if opts.psd_inverse_length[det] is not None: - pad = int(numpy.ceil(opts.psd_inverse_length[det] / 2)) - logging.info("Padding %s analysis start and end times by %d " - "(= psd-inverse-length/2) seconds to " - "account for PSD wrap around effects.", - det, pad) - else: - pad = 0 - gps_start[det] -= pad - gps_end[det] += pad - if opts.psd_start_time[det] is not None: - opts.psd_start_time[det] += opts.trigger_time - if opts.psd_end_time[det] is not None: - opts.psd_end_time[det] += opts.trigger_time - opts.gps_start_time = gps_start - opts.gps_end_time = gps_end - # check for the frequencies - low_freq_cutoff = filter_flow.copy() - if opts.low_frequency_cutoff: - # add in any missing detectors - low_freq_cutoff.update({det: opts.low_frequency_cutoff[det] - for det in opts.instruments - if opts.low_frequency_cutoff[det] is not None}) - # make sure the data conditioning low frequency cutoff is < than - # the matched filter cutoff - if any(low_freq_cutoff[det] > filter_flow[det] for det in filter_flow): - raise ValueError("data conditioning low frequency cutoff must " - "be less than the filter low frequency " - "cutoff") - opts.low_frequency_cutoff = low_freq_cutoff - - # verify options are sane - verify_psd_options_multi_ifo(opts, parser, opts.instruments) - verify_strain_options_multi_ifo(opts, parser, opts.instruments) - return opts - - -def data_from_cli(opts, check_for_valid_times=False, - shift_psd_times_to_valid=False, - err_on_missing_detectors=False): - """Loads the data needed for a model from the given command-line options. - - Gates specifed on the command line are also applied. - - Parameters - ---------- - opts : ArgumentParser parsed args - Argument options parsed from a command line string (the sort of thing - returned by `parser.parse_args`). - check_for_valid_times : bool, optional - Check that valid data exists in the requested gps times. Default is - False. 
- shift_psd_times_to_valid : bool, optional - If estimating the PSD from data, shift the PSD times to a valid - segment if needed. Default is False. - err_on_missing_detectors : bool, optional - Raise a NoValidDataError if any detector does not have valid data. - Otherwise, a warning is printed, and that detector is skipped. - - Returns - ------- - strain_dict : dict - Dictionary of detectors -> time series strain. - psd_strain_dict : dict or None - If ``opts.psd_(start|end)_time`` were set, a dctionary of - detectors -> time series data to use for PSD estimation. Otherwise, - ``None``. - """ - # get gates to apply - gates = gates_from_cli(opts) - psd_gates = psd_gates_from_cli(opts) - - # get strain time series - instruments = opts.instruments - - # validate times - if check_for_valid_times: - dets_with_data = detectors_with_valid_data( - instruments, opts.gps_start_time, opts.gps_end_time, - pad_data=opts.pad_data, - err_on_missing_detectors=err_on_missing_detectors, - shift_to_valid=False, - segment_name=opts.dq_segment_name, source=opts.dq_source, - server=opts.dq_server, veto_definer=opts.veto_definer) - # reset instruments to only be those with valid data - instruments = list(dets_with_data.keys()) - - strain_dict = strain_from_cli_multi_ifos(opts, instruments, - precision="double") - # apply gates if not waiting to overwhiten - if not opts.gate_overwhitened: - strain_dict = apply_gates_to_td(strain_dict, gates) - - # check that there aren't nans in the data - check_for_nans(strain_dict) - - # get strain time series to use for PSD estimation - # if user has not given the PSD time options then use same data as analysis - if opts.psd_start_time and opts.psd_end_time: - logging.info("Will generate a different time series for PSD " - "estimation") - if check_for_valid_times: - psd_times = detectors_with_valid_data( - instruments, opts.psd_start_time, opts.psd_end_time, - pad_data=opts.pad_data, - err_on_missing_detectors=err_on_missing_detectors, - shift_to_valid=shift_psd_times_to_valid, - segment_name=opts.dq_segment_name, source=opts.dq_source, - server=opts.dq_server, veto_definer=opts.veto_definer) - # remove detectors from the strain dict that did not have valid - # times for PSD estimation - for det in set(strain_dict.keys())-set(psd_times.keys()): - _ = strain_dict.pop(det) - # reset instruments to only be those with valid data - instruments = list(psd_times.keys()) - else: - psd_times = {det: (opts.psd_start_time[det], - opts.psd_end_time[det]) - for det in instruments} - psd_strain_dict = {} - for det, (psd_start, psd_end) in psd_times.items(): - opts.gps_start_time = psd_start - opts.gps_end_time = psd_end - psd_strain_dict.update( - strain_from_cli_multi_ifos(opts, [det], precision="double")) - # apply any gates - logging.info("Applying gates to PSD data") - psd_strain_dict = apply_gates_to_td(psd_strain_dict, psd_gates) - # check that there aren't nans in the psd data - check_for_nans(psd_strain_dict) - elif opts.psd_start_time or opts.psd_end_time: - raise ValueError("Must give psd-start-time and psd-end-time") - else: - psd_strain_dict = None - - # check that we have data left to analyze - if instruments == []: - raise NoValidDataError("No valid data could be found in any of the " - "requested instruments.") - - return strain_dict, psd_strain_dict - - -def fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict=None): - """Converts a dictionary of time series to the frequency domain, and gets - the PSDs. 
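# Plain-numpy sketch of the time-to-frequency conversion described above: an rfft
# of N real samples at spacing delta_t has N//2 + 1 bins spaced by
# delta_f = 1 / (N * delta_t).  The normalization and types here are stand-ins,
# not pycbc.types.
import numpy

delta_t = 1.0 / 4096
duration = 8                                   # seconds
ts = numpy.random.normal(size=int(duration / delta_t))
stilde = numpy.fft.rfft(ts) * delta_t          # rough continuous-FT normalization
delta_f = 1.0 / duration
print(len(stilde), delta_f)                    # 16385 0.125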
- - Parameters - ---------- - opts : ArgumentParser parsed args - Argument options parsed from a command line string (the sort of thing - returned by `parser.parse_args`). - strain_dict : dict - Dictionary of detectors -> time series data. - psd_strain_dict : dict, optional - Dictionary of detectors -> time series data to use for PSD estimation. - If not provided, will use the ``strain_dict``. This is - ignored if ``opts.psd_estimation`` is not set. See - :py:func:`pycbc.psd.psd_from_cli_multi_ifos` for details. - - Returns - ------- - stilde_dict : dict - Dictionary of detectors -> frequency series data. - psd_dict : dict - Dictionary of detectors -> frequency-domain PSDs. - """ - # FFT strain and save each of the length of the FFT, delta_f, and - # low frequency cutoff to a dict - stilde_dict = {} - length_dict = {} - delta_f_dict = {} - for det, tsdata in strain_dict.items(): - stilde_dict[det] = tsdata.to_frequencyseries() - length_dict[det] = len(stilde_dict[det]) - delta_f_dict[det] = stilde_dict[det].delta_f - - if psd_strain_dict is None: - psd_strain_dict = strain_dict - - # get PSD as frequency series - psd_dict = psd_from_cli_multi_ifos( - opts, length_dict, delta_f_dict, opts.low_frequency_cutoff, - list(psd_strain_dict.keys()), strain_dict=psd_strain_dict, - precision="double") - - return stilde_dict, psd_dict - - -def gate_overwhitened_data(stilde_dict, psd_dict, gates): - """Applies gates to overwhitened data. - - Parameters - ---------- - stilde_dict : dict - Dictionary of detectors -> frequency series data to apply the gates to. - psd_dict : dict - Dictionary of detectors -> PSD to use for overwhitening. - gates : dict - Dictionary of detectors -> gates. - - Returns - ------- - dict : - Dictionary of detectors -> frequency series data with the gates - applied after overwhitening. The returned data is not overwhitened. - """ - logging.info("Applying gates to overwhitened data") - # overwhiten the data - out = {} - for det in gates: - out[det] = stilde_dict[det] / psd_dict[det] - # now apply the gate - out = apply_gates_to_fd(out, gates) - # now unwhiten - for det in gates: - out[det] *= psd_dict[det] - return out diff --git a/pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-checkpoint.py deleted file mode 100644 index 3c10bb32aa9..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-checkpoint.py +++ /dev/null @@ -1,1110 +0,0 @@ -# Copyright (C) 2020 Collin Capano and Shilpa Kastha -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - -"""This module provides model classes that assume the noise is Gaussian and -introduces a gate to remove given times from the data, using the inpainting -method to fill the removed part such that it does not enter the likelihood. 
-""" - -from abc import abstractmethod -import logging -import numpy -import scipy -from scipy import special - -from pycbc.waveform import (NoWaveformError, FailedWaveformError) -from pycbc.types import FrequencySeries -from pycbc.detector import Detector -from pycbc.pnutils import hybrid_meco_frequency -from pycbc.psd import interpolate -from pycbc import types -from pycbc.waveform.utils import time_from_frequencyseries -from pycbc.waveform import generator -from pycbc.filter import highpass -from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator) -from .base_data import BaseDataModel -from .data_utils import fd_data_from_strain_dict - - -class BaseGatedGaussian(BaseGaussianNoise): - r"""Base model for gated gaussian. - - Provides additional routines for applying a time-domain gate to data. - See :py:class:`GatedGaussianNoise` for more details. - """ - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, highpass_waveforms=False, **kwargs): - # we'll want the time-domain data, so store that - self._td_data = {} - # cache the current projection for debugging - self.current_proj = {} - self.current_nproj = {} - # cache the overwhitened data - self._overwhitened_data = {} - # cache the current gated data - self._gated_data = {} - # cache covariance matrix and normalization terms - self._Rss = {} - self._cov = {} - self._lognorm = {} - # cache samples and linear regression for determinant extrapolation - self._cov_samples = {} - self._cov_regressions = {} - # highpass waveforms with the given frequency - self.highpass_waveforms = highpass_waveforms - if self.highpass_waveforms: - logging.info("Will highpass waveforms at %f Hz", - highpass_waveforms) - # set up the boiler-plate attributes - super().__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - - @classmethod - def from_config(cls, cp, data_section='data', data=None, psds=None, - **kwargs): - """Adds highpass filtering to keyword arguments based on config file. - """ - if cp.has_option(data_section, 'strain-high-pass') and \ - 'highpass_waveforms' not in kwargs: - kwargs['highpass_waveforms'] = float(cp.get(data_section, - 'strain-high-pass')) - return super().from_config(cp, data_section=data_section, - data=data, psds=psds, - **kwargs) - - @BaseDataModel.data.setter - def data(self, data): - """Store a copy of the FD and TD data.""" - BaseDataModel.data.fset(self, data) - # store the td version - self._td_data = {det: d.to_timeseries() for det, d in data.items()} - - @property - def td_data(self): - """The data in the time domain.""" - return self._td_data - - @BaseGaussianNoise.psds.setter - def psds(self, psds): - """Sets the psds, and calculates the weight and norm from them. - The data and the low and high frequency cutoffs must be set first. 
- """ - # check that the data has been set - if self._data is None: - raise ValueError("No data set") - if self._f_lower is None: - raise ValueError("low frequency cutoff not set") - if self._f_upper is None: - raise ValueError("high frequency cutoff not set") - # make sure the relevant caches are cleared - self._psds.clear() - self._invpsds.clear() - self._gated_data.clear() - # store the psds - for det, d in self._data.items(): - if psds is None: - # No psd means assume white PSD - p = FrequencySeries(numpy.ones(int(self._N/2+1)), - delta_f=d.delta_f) - else: - # copy for storage - p = psds[det].copy() - self._psds[det] = p - # we'll store the weight to apply to the inner product - invp = 1./p - self._invpsds[det] = invp - # store the autocorrelation function and covariance matrix for each detector - Rss = p.astype(types.complex_same_precision_as(p)).to_timeseries() - cov = scipy.linalg.toeplitz(Rss/2) # full covariance matrix - self._Rss[det] = Rss - self._cov[det] = cov - # calculate and store the linear regressions to extrapolate determinant values - if self.normalize: - samples, fit = self.logdet_fit(cov, p) - self._cov_samples[det] = samples - self._cov_regressions[det] = fit - self._overwhitened_data = self.whiten(self.data, 2, inplace=False) - - def logdet_fit(self, cov, p): - """Construct a linear regression from a sample of truncated covariance matrices. - - Returns the sample points used for linear fit generation as well as the linear fit parameters. - """ - # initialize lists for matrix sizes and determinants - sample_sizes = [] - sample_dets = [] - # set sizes of sample matrices; ensure exact calculations are only on small matrices - s = cov.shape[0] - max_size = 8192 - if s > max_size: - sample_sizes = [s, max_size, max_size//2, max_size//4] - else: - sample_sizes = [s, s//2, s//4, s//8] - for i in sample_sizes: - # calculate logdet of the full matrix using circulant eigenvalue approximation - if i == s: - ld = 2*numpy.log(p/(2*p.delta_t)).sum() - sample_dets.append(ld) - # generate three more sample matrices using exact calculations - else: - gate_size = s - i - start = (s - gate_size)//2 - end = start + gate_size - tc = numpy.delete(numpy.delete(cov, slice(start, end), 0), slice(start, end), 1) - ld = numpy.linalg.slogdet(tc)[1] - sample_dets.append(ld) - # generate a linear regression using the four points (size, logdet) - x = numpy.vstack([sample_sizes, numpy.ones(len(sample_sizes))]).T - m, b = numpy.linalg.lstsq(x, sample_dets, rcond=None)[0] - return (sample_sizes, sample_dets), (m, b) - - def gate_indices(self, det): - """Calculate the indices corresponding to start and end of gate. - """ - # get time series start and delta_t - ts = self._Rss[det] - start_time_gc = float(self.td_data[det].start_time) # need float for conversion input - delta_t = ts.delta_t - # get gate start and length from get_gate_times - gate_start, gate_length = self.get_gate_times()[det] - # convert to indices - lindex = float(gate_start - start_time_gc) // delta_t - rindex = lindex + (gate_length // delta_t) - lindex = lindex if lindex >= 0 else 0 - rindex = rindex if rindex <= len(ts) else len(ts) - return lindex, rindex - - def det_lognorm(self, det, start_index=None, end_index=None): - """Calculate the normalization term from the truncated covariance matrix. 
- Determinant is estimated using a linear fit to logdet vs truncated matrix size - """ - if not self.normalize: - return 0 - try: - # check if the key already exists; if so, return its value - lognorm = self._lognorm[(det, start_index, end_index)] - except KeyError: - # if not, extrapolate the normalization term - start_index, end_index = self.gate_indices(det) - # get the size of the matrix - cov = self._cov[det] - n = cov.shape[0] - trunc_size = n - (end_index - start_index) - # call the linear regression - m, b = self._cov_regressions[det] - # extrapolate from linear fit - ld = m*trunc_size + b - lognorm = -0.5*(numpy.log(2*numpy.pi)*trunc_size + ld) # full normalization term - # cache the result - self._lognorm[(det, start_index, end_index)] = lognorm - return lognorm - - @property - def normalize(self): - """Determines if the loglikelihood includes the normalization term. - """ - return self._normalize - - ### This is called before psds, so doing direct calls to self._cov, self._psds, etc. throws an error for now ### - - @normalize.setter - def normalize(self, normalize): - """Clears the current stats if the normalization state is changed. - """ - self._normalize = normalize - # set covariance det linear regression iff normalize is set to true and the respective dicts are empty - # if self.normalize and self._cov_samples == {} and self._cov_regressions == {}: - # for det, d in self._data.items(): - # cov = self._cov[det] - # samples, fit = self.logdet_fit(cov, p) - # self._cov_samples[det] = samples - # self._cov_regressions[det] = fit - - @staticmethod - def _nowaveform_logl(): - """Convenience function to set logl values if no waveform generated. - """ - return -numpy.inf - - def _loglr(self): - r"""Computes the log likelihood ratio. - Returns - ------- - float - The value of the log likelihood ratio evaluated at the given point. - """ - return self._loglikelihood() - self._lognl() - - def whiten(self, data, whiten, inplace=False): - """Whitens the given data. - - Parameters - ---------- - data : dict - Dictionary of detector names -> FrequencySeries. - whiten : {0, 1, 2} - Integer indicating what level of whitening to apply. Levels are: - 0: no whitening; 1: whiten; 2: overwhiten. - inplace : bool, optional - If True, modify the data in place. Otherwise, a copy will be - created for whitening. - - - Returns - ------- - dict : - Dictionary of FrequencySeries after the requested whitening has - been applied. - """ - if not inplace: - data = {det: d.copy() for det, d in data.items()} - if whiten: - for det, dtilde in data.items(): - invpsd = self._invpsds[det] - if whiten == 1: - dtilde *= invpsd**0.5 - elif whiten == 2: - dtilde *= invpsd - else: - raise ValueError("whiten must be either 0, 1, or 2") - return data - - def get_waveforms(self): - """The waveforms generated using the current parameters. - - If the waveforms haven't been generated yet, they will be generated, - resized to the same length as the data, and cached. If the - ``highpass_waveforms`` attribute is set, a highpass filter will - also be applied to the waveforms. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. 
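# Quick check of the Gaussian normalization term used in det_lognorm above: for an
# n-dimensional Gaussian with covariance C the log normalization is
# -0.5 * (n*log(2*pi) + logdet(C)), which equals scipy's multivariate-normal
# log-pdf evaluated at the mean.  The covariance here is arbitrary.
import numpy
from scipy.stats import multivariate_normal

cov = numpy.array([[2.0, 0.3],
                   [0.3, 1.0]])
n = cov.shape[0]
lognorm = -0.5 * (n * numpy.log(2 * numpy.pi) + numpy.linalg.slogdet(cov)[1])
print(lognorm)
print(multivariate_normal(mean=numpy.zeros(n), cov=cov).logpdf(numpy.zeros(n)))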
- """ - if self._current_wfs is None: - params = self.current_params - wfs = self.waveform_generator.generate(**params) - for det, h in wfs.items(): - # make the same length as the data - h.resize(len(self.data[det])) - # apply high pass - if self.highpass_waveforms: - h = highpass( - h.to_timeseries(), - frequency=self.highpass_waveforms).to_frequencyseries() - wfs[det] = h - self._current_wfs = wfs - return self._current_wfs - - @abstractmethod - def get_gated_waveforms(self): - """Generates and gates waveforms using the current parameters. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - pass - - def get_residuals(self): - """Generates the residuals ``d-h`` using the current parameters. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - wfs = self.get_waveforms() - - out = {} - for det, h in wfs.items(): - d = self.data[det] - out[det] = d - h - return out - - def get_data(self): - """Return a copy of the data. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - return {det: d.copy() for det, d in self.data.items()} - - def get_gated_data(self): - """Return a copy of the gated data. - - The gated data will be cached for faster retrieval. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - gate_times = self.get_gate_times() - out = {} - for det, d in self.td_data.items(): - # make sure the cache at least has the detectors in it - try: - cache = self._gated_data[det] - except KeyError: - cache = self._gated_data[det] = {} - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - try: - dtilde = cache[gatestartdelay, dgatedelay] - except KeyError: - # doesn't exist yet, or the gate times changed - cache.clear() - d = d.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - dtilde = d.to_frequencyseries() - # save for next time - cache[gatestartdelay, dgatedelay] = dtilde - out[det] = dtilde - return out - - def get_gate_times(self): - """Gets the time to apply a gate based on the current sky position. - - If the parameter ``gatefunc`` is set to ``'hmeco'``, the gate times - will be calculated based on the hybrid MECO of the given set of - parameters; see ``get_gate_times_hmeco`` for details. Otherwise, the - gate times will just be retrieved from the ``t_gate_start`` and - ``t_gate_end`` parameters. - - .. warning:: - Since the normalization of the likelihood is currently not - being calculated, it is recommended that you do not use - ``gatefunc``, instead using fixed gate times. 
- - Returns - ------- - dict : - Dictionary of detector names -> (gate start, gate width) - """ - params = self.current_params - try: - gatefunc = self.current_params['gatefunc'] - except KeyError: - gatefunc = None - if gatefunc == 'hmeco': - return self.get_gate_times_hmeco() - # gate input for ringdown analysis which consideres a start time - # and an end time - gatestart = params['t_gate_start'] - gateend = params['t_gate_end'] - # we'll need the sky location for determining time shifts - ra = self.current_params['ra'] - dec = self.current_params['dec'] - gatetimes = {} - for det in self._invpsds: - thisdet = Detector(det) - # account for the time delay between the waveforms of the - # different detectors - gatestartdelay = gatestart + thisdet.time_delay_from_earth_center(ra, dec, gatestart) - gateenddelay = gateend + thisdet.time_delay_from_earth_center(ra, dec, gateend) - dgatedelay = gateenddelay - gatestartdelay - gatetimes[det] = (gatestartdelay, dgatedelay) - return gatetimes - - def get_gate_times_hmeco(self): - """Gets the time to apply a gate based on the current sky position. - Returns - ------- - dict : - Dictionary of detector names -> (gate start, gate width) - """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get waveform parameters - params = self.current_params - spin1 = params['spin1z'] - spin2 = params['spin2z'] - # gate input for ringdown analysis which consideres a start time - # and an end time - dgate = params['gate_window'] - meco_f = hybrid_meco_frequency(params['mass1'], params['mass2'], spin1, spin2) - # figure out the gate times - gatetimes = {} - for det, h in wfs.items(): - invpsd = self._invpsds[det] - h.resize(len(invpsd)) - ht = h.to_timeseries() - f_low = int((self._f_lower[det]+1)/h.delta_f) - sample_freqs = h.sample_frequencies[f_low:].numpy() - f_idx = numpy.where(sample_freqs <= meco_f)[0][-1] - # find time corresponding to meco frequency - t_from_freq = time_from_frequencyseries( - h[f_low:], sample_frequencies=sample_freqs) - if t_from_freq[f_idx] > 0: - gatestartdelay = t_from_freq[f_idx] + float(t_from_freq.epoch) - else: - gatestartdelay = t_from_freq[f_idx] + ht.sample_times[-1] - gatestartdelay = min(gatestartdelay, params['t_gate_start']) - gatetimes[det] = (gatestartdelay, dgate) - return gatetimes - - def _lognl(self): - """Calculates the log of the noise likelihood. - """ - # clear variables - lognl = 0. 
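(Illustrative sketch, not part of the original file: the gate-time logic above shifts the geocenter gate boundaries to each detector by its light-travel-time delay. The helper name and the sky position / GPS times below are hypothetical; only pycbc.detector.Detector and its time_delay_from_earth_center method are the real PyCBC calls.)

    from pycbc.detector import Detector

    def per_detector_gate_times(gatestart, gateend, ra, dec, detectors=('H1', 'L1')):
        # Shift geocenter gate boundaries into each detector's frame and
        # return {det: (gate start, gate width)}, mirroring get_gate_times.
        gatetimes = {}
        for det in detectors:
            d = Detector(det)
            start = gatestart + d.time_delay_from_earth_center(ra, dec, gatestart)
            end = gateend + d.time_delay_from_earth_center(ra, dec, gateend)
            gatetimes[det] = (start, end - start)
        return gatetimes

    # hypothetical sky position and GPS times, for illustration only
    print(per_detector_gate_times(1126259462.40, 1126259462.44, ra=1.95, dec=-1.27))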
- self._det_lognls.clear() - # get the times of the gates - gate_times = self.get_gate_times() - self.current_nproj.clear() - for det, invpsd in self._invpsds.items(): - start_index, end_index = self.gate_indices(det) - norm = self.det_lognorm(det, start_index, end_index) # linear estimation - gatestartdelay, dgatedelay = gate_times[det] - # we always filter the entire segment starting from kmin, since the - # gated series may have high frequency components - slc = slice(self._kmin[det], self._kmax[det]) - # gate the data - data = self.td_data[det] - gated_dt = data.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - self.current_nproj[det] = (gated_dt.proj, gated_dt.projslc) - # convert to the frequency series - gated_d = gated_dt.to_frequencyseries() - # overwhiten - gated_d *= invpsd - d = self.data[det] - # inner product - ip = 4 * invpsd.delta_f * d[slc].inner(gated_d[slc]).real # - dd = norm - 0.5*ip - # store - self._det_lognls[det] = dd - lognl += dd - return float(lognl) - - def det_lognl(self, det): - # make sure lognl has been called - _ = self._trytoget('lognl', self._lognl) - # the det_lognls dict should have been updated, so can call it now - return self._det_lognls[det] - - @staticmethod - def _fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict): - """Wrapper around :py:func:`data_utils.fd_data_from_strain_dict`. - - Ensures that if the PSD is estimated from data, the inverse spectrum - truncation uses a Hann window, and that the low frequency cutoff is - zero. - """ - if opts.psd_inverse_length and opts.invpsd_trunc_method is None: - # make sure invpsd truncation is set to hanning - logging.info("Using Hann window to truncate inverse PSD") - opts.invpsd_trunc_method = 'hann' - lfs = None - if opts.psd_estimation: - # make sure low frequency cutoff is zero - logging.info("Setting low frequency cutoff of PSD to 0") - lfs = opts.low_frequency_cutoff.copy() - opts.low_frequency_cutoff = {d: 0. for d in lfs} - out = fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict) - # set back - if lfs is not None: - opts.low_frequency_cutoff = lfs - return out - - def write_metadata(self, fp, group=None): - """Adds writing the psds, and analyzed detectors. - - The analyzed detectors, their analysis segments, and the segments - used for psd estimation are written as - ``analyzed_detectors``, ``{{detector}}_analysis_segment``, and - ``{{detector}}_psd_segment``, respectively. These are either written - to the specified ``group``'s attrs, or to the top level attrs if - ``group`` is None. - - Parameters - ---------- - fp : pycbc.inference.io.BaseInferenceFile instance - The inference file to write to. - group : str, optional - If provided, the metadata will be written to the attrs specified - by group, i.e., to ``fp[group].attrs``. Otherwise, metadata is - written to the top-level attrs (``fp.attrs``). 
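(Illustrative sketch, not part of the original file: the per-detector noise term computed in _lognl above is norm - 0.5*<d, d_gated>, with the inner product evaluated as 4*delta_f times the real part of the band-limited, overwhitened product. The helper and the toy arrays below are made up and only stand in for the FrequencySeries objects.)

    import numpy as np

    def band_limited_inner_product(a, b, invpsd, delta_f, kmin, kmax):
        # 4*delta_f * Re sum_k conj(a_k) b_k / S_n(f_k), restricted to the band
        integrand = np.conj(a[kmin:kmax]) * b[kmin:kmax] * invpsd[kmin:kmax]
        return 4.0 * delta_f * integrand.sum().real

    # toy white-noise example standing in for d and its gated counterpart
    rng = np.random.default_rng(0)
    d = rng.normal(size=1024) + 1j * rng.normal(size=1024)
    invpsd = np.ones(1024)
    ip = band_limited_inner_product(d, d, invpsd, delta_f=0.25, kmin=10, kmax=1000)
    lognl = -0.5 * ip   # plus the normalization term handled by det_lognorm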
- """ - BaseDataModel.write_metadata(self, fp) - attrs = fp.getattrs(group=group) - # write the analyzed detectors and times - attrs['analyzed_detectors'] = self.detectors # store fitting values here - for det, data in self.data.items(): - key = '{}_analysis_segment'.format(det) - attrs[key] = [float(data.start_time), float(data.end_time)] - # store covariance determinant extrapolation information (checkpoint) - if self.normalize: - attrs['{}_cov_sample'.format(det)] = self._cov_samples[det] - attrs['{}_cov_regression'.format(det)] = self._cov_regressions[det] - if self._psds is not None and not self.no_save_data: - fp.write_psd(self._psds, group=group) - # write the times used for psd estimation (if they were provided) - for det in self.psd_segments: - key = '{}_psd_segment'.format(det) - attrs[key] = list(map(float, self.psd_segments[det])) - # save the frequency cutoffs - for det in self.detectors: - attrs['{}_likelihood_low_freq'.format(det)] = self._f_lower[det] - if self._f_upper[det] is not None: - attrs['{}_likelihood_high_freq'.format(det)] = \ - self._f_upper[det] - - -class GatedGaussianNoise(BaseGatedGaussian): - r"""Model that applies a time domain gate, assuming stationary Gaussian - noise. - - The gate start and end times are set by providing ``t_gate_start`` and - ``t_gate_end`` parameters, respectively. This will cause the gated times - to be excised from the analysis. For more details on the likelihood - function and its derivation, see - `arXiv:2105.05238 `_. - - .. warning:: - The normalization of the likelihood depends on the gate times. However, - at the moment, the normalization is not calculated, as it depends on - the determinant of the truncated covariance matrix (see Eq. 4 of - arXiv:2105.05238). For this reason it is recommended that you only - use this model for fixed gate times. - - """ - name = 'gated_gaussian_noise' - - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, **kwargs): - # set up the boiler-plate attributes - super().__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - # create the waveform generator - self.waveform_generator = create_waveform_generator( - self.variable_params, self.data, - waveform_transforms=self.waveform_transforms, - recalibration=self.recalibration, - gates=self.gates, **self.static_params) - - @property - def _extra_stats(self): - """No extra stats are stored.""" - return [] - - ### This needs to be called after running model.update() for changes to take effect ### - - def _loglikelihood(self): - r"""Computes the log likelihood after removing the power within the - given time window, - - .. math:: - \log p(d|\Theta) = -\frac{1}{2} \sum_i - \left< d_i - h_i(\Theta) | d_i - h_i(\Theta) \right>, - - at the current parameter values :math:`\Theta`. - - Returns - ------- - float - The value of the log likelihood. - """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get the times of the gates - gate_times = self.get_gate_times() - logl = 0. 
- self.current_proj.clear() - for det, h in wfs.items(): - invpsd = self._invpsds[det] - start_index, end_index = self.gate_indices(det) - norm = self.det_lognorm(det, start_index, end_index) - gatestartdelay, dgatedelay = gate_times[det] - # we always filter the entire segment starting from kmin, since the - # gated series may have high frequency components - slc = slice(self._kmin[det], self._kmax[det]) - # calculate the residual - data = self.td_data[det] - ht = h.to_timeseries() - res = data - ht - rtilde = res.to_frequencyseries() - gated_res = res.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - self.current_proj[det] = (gated_res.proj, gated_res.projslc) - gated_rtilde = gated_res.to_frequencyseries() - # overwhiten - gated_rtilde *= invpsd - rr = 4 * invpsd.delta_f * rtilde[slc].inner(gated_rtilde[slc]).real - logl += norm - 0.5*rr - return float(logl) - - @property - def multi_signal_support(self): - """ The list of classes that this model supports in a multi-signal - likelihood - """ - return [type(self)] - - def multi_loglikelihood(self, models): - """ Calculate a multi-model (signal) likelihood - """ - # Generate the waveforms for each submodel - wfs = [] - for m in models + [self]: - # temp fix for wfs in combined run - # set static params 'zero_before_gate' or 'zero_after_gate' to specify portion of wf to zero out - wf = m.get_waveforms() - # if params don't exist, set them to False - # if they do exist, their value doesn't matter; strings always return True - try: - m.current_params['zero_before_gate'] - except KeyError: - m.current_params['zero_before_gate'] = False - try: - m.current_params['zero_after_gate'] - except KeyError: - m.current_params['zero_after_gate'] = False - - if m.current_params['zero_before_gate'] or m.current_params['zero_after_gate']: - gate_times = m.get_gate_times() - for d in wf: - ts = wf[d] - start = gate_times[d][0] - gpsidx = (float(start) - float(ts.start_time))//ts.delta_t - ts = ts.to_timeseries() - # zero out inspiral wf after gate - if m.current_params['zero_after_gate']: - ts[int(gpsidx):] *= 0 - # zero out ringdown wf before gate - if m.current_params['zero_before_gate']: - ts[:int(gpsidx)] *= 0 - ts = ts.to_frequencyseries() - wf[d] = ts - wfs.append(wf) - - # combine into a single waveform - combine = {} - for det in self.data: - # get max waveform length - mlen = max([len(x[det]) for x in wfs]) - [x[det].resize(mlen) for x in wfs] - combine[det] = sum([x[det] for x in wfs]) - - self._current_wfs = combine - return self._loglikelihood() - - def get_gated_waveforms(self): - wfs = self.get_waveforms() - gate_times = self.get_gate_times() - out = {} - - # temp fix for hierarchical runs - # zeroes out pre-merger for ringdown, post-merger for inspiral - # will need a more elegant solution later; for now copy code from multi_loglikelihood - try: - self.current_params['zero_before_gate'] - except KeyError: - self.current_params['zero_before_gate'] = False - try: - self.current_params['zero_after_gate'] - except KeyError: - self.current_params['zero_after_gate'] = False - - if self.current_params['zero_before_gate'] or self.current_params['zero_after_gate']: - for d in wfs: - ts = wfs[d] - start = gate_times[d][0] - gpsidx = (float(start) - float(ts.start_time))//ts.delta_t - ts = ts.to_timeseries() - # zero out inspiral wf after gate - if self.current_params['zero_after_gate']: - ts[int(gpsidx):] *= 0 - # zero out ringdown wf before gate - if self.current_params['zero_before_gate']: - 
ts[:int(gpsidx)] *= 0 - ts = ts.to_frequencyseries() - wfs[d] = ts - - # apply the gate - for det, h in wfs.items(): - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - ht = h.to_timeseries() - ht = ht.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=False, - invpsd=invpsd, method='paint') - h = ht.to_frequencyseries() - out[det] = h - return out - - def get_gated_residuals(self): - """Generates the gated residuals ``d-h`` using the current parameters. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - params = self.current_params - wfs = self.waveform_generator.generate(**params) - gate_times = self.get_gate_times() - out = {} - for det, h in wfs.items(): - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - data = self.td_data[det] - ht = h.to_timeseries() - res = data - ht - res = res.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - res = res.to_frequencyseries() - out[det] = res - return out - - -class GatedGaussianMargPol(BaseGatedGaussian): - r"""Gated gaussian model with numerical marginalization over polarization. - - This implements the GatedGaussian likelihood with an explicit numerical - marginalization over polarization angle. This is accomplished using - a fixed set of integration points distribution uniformation between - 0 and 2pi. By default, 1000 integration points are used. - The 'polarization_samples' argument can be passed to set an alternate - number of integration points. - """ - name = 'gated_gaussian_margpol' - - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, - polarization_samples=1000, **kwargs): - # set up the boiler-plate attributes - super().__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - # the polarization parameters - self.polarization_samples = polarization_samples - self.pol = numpy.linspace(0, 2*numpy.pi, self.polarization_samples) - self.dets = {} - # create the waveform generator - self.waveform_generator = create_waveform_generator( - self.variable_params, self.data, - waveform_transforms=self.waveform_transforms, - recalibration=self.recalibration, - generator_class=generator.FDomainDetFrameTwoPolGenerator, - **self.static_params) - - def get_waveforms(self): - if self._current_wfs is None: - params = self.current_params - wfs = self.waveform_generator.generate(**params) - for det, (hp, hc) in wfs.items(): - # make the same length as the data - hp.resize(len(self.data[det])) - hc.resize(len(self.data[det])) - # apply high pass - if self.highpass_waveforms: - hp = highpass( - hp.to_timeseries(), - frequency=self.highpass_waveforms).to_frequencyseries() - hc = highpass( - hc.to_timeseries(), - frequency=self.highpass_waveforms).to_frequencyseries() - wfs[det] = (hp, hc) - self._current_wfs = wfs - return self._current_wfs - - def get_gated_waveforms(self): - wfs = self.get_waveforms() - gate_times = self.get_gate_times() - out = {} - - # temp fix for hierarchical runs - # zero out pre-merger for ringdown, post-merger for inspiral - # need a more elegant solution later; for now just copy code from multi_loglikelihood - try: - self.current_params['zero_before_gate'] - except KeyError: - self.current_params['zero_before_gate'] = False - try: - 
self.current_params['zero_after_gate'] - except KeyError: - self.current_params['zero_after_gate'] = False - - if self.current_params['zero_before_gate'] or self.current_params['zero_after_gate']: - for d in wfs: - tsp, tsc = wfs[d] - start = gate_times[d][0] - # plus and cross pols probably have the same gpsidx, but get both just to be safe - gpsidxp = (float(start) - float(tsp.start_time))//tsp.delta_t - gpsidxc = (float(start) - float(tsc.start_time))//tsc.delta_t - tsp = tsp.to_timeseries() - tsc = tsc.to_timeseries() - # zero out inspiral wf after gate - if self.current_params['zero_after_gate']: - tsp[int(gpsidxp):] *= 0 - tsc[int(gpsidxc):] *= 0 - # zero out ringdown wf before gate - if self.current_params['zero_before_gate']: - tsp[:int(gpsidxp)] *= 0 - tsc[:int(gpsidxc)] *= 0 - tsp = tsp.to_frequencyseries() - tsc = tsc.to_frequencyseries() - wfs[d] = (tsp, tsc) - - for det in wfs: - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - # the waveforms are a dictionary of (hp, hc) - pols = [] - for h in wfs[det]: - ht = h.to_timeseries() - try: - ht = ht.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=False, - invpsd=invpsd, method='paint') - h = ht.to_frequencyseries() - except ValueError as e: - numpy.save('fail_params.out', self.current_params, allow_pickle=True) - ht.save('fail_wf.hdf') - raise e - pols.append(h) - out[det] = tuple(pols) - return out - - def get_gate_times_hmeco(self): - """Gets the time to apply a gate based on the current sky position. - Returns - ------- - dict : - Dictionary of detector names -> (gate start, gate width) - """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get waveform parameters - params = self.current_params - spin1 = params['spin1z'] - spin2 = params['spin2z'] - # gate input for ringdown analysis which consideres a start time - # and an end time - dgate = params['gate_window'] - meco_f = hybrid_meco_frequency(params['mass1'], params['mass2'], spin1, spin2) - # figure out the gate times - gatetimes = {} - # for now only calculating time from plus polarization; should be all that's necessary - for det, (hp, hc) in wfs.items(): - invpsd = self._invpsds[det] - hp.resize(len(invpsd)) - ht = hp.to_timeseries() - f_low = int((self._f_lower[det]+1)/hp.delta_f) - sample_freqs = hp.sample_frequencies[f_low:].numpy() - f_idx = numpy.where(sample_freqs <= meco_f)[0][-1] - # find time corresponding to meco frequency - t_from_freq = time_from_frequencyseries( - hp[f_low:], sample_frequencies=sample_freqs) - if t_from_freq[f_idx] > 0: - gatestartdelay = t_from_freq[f_idx] + float(t_from_freq.epoch) - else: - gatestartdelay = t_from_freq[f_idx] + ht.sample_times[-1] - gatestartdelay = min(gatestartdelay, params['t_gate_start']) - gatetimes[det] = (gatestartdelay, dgate) - return gatetimes - - @property - def _extra_stats(self): - """Adds the maxL polarization and corresponding likelihood.""" - return ['maxl_polarization', 'maxl_logl'] - - def _loglikelihood(self): - r"""Computes the log likelihood after removing the power within the - given time window, - - .. math:: - \log p(d|\Theta) = -\frac{1}{2} \sum_i - \left< d_i - h_i(\Theta) | d_i - h_i(\Theta) \right>, - - at the current parameter values :math:`\Theta`. - - Returns - ------- - float - The value of the log likelihood. 
- """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get the gated waveforms and data - gated_wfs = self.get_gated_waveforms() - gated_data = self.get_gated_data() - # cycle over - loglr = 0. - lognl = 0. - for det, (hp, hc) in wfs.items(): - # get the antenna patterns - if det not in self.dets: - self.dets[det] = Detector(det) - fp, fc = self.dets[det].antenna_pattern(self.current_params['ra'], - self.current_params['dec'], - self.pol, - self.current_params['tc']) - start_index, end_index = self.gate_indices(det) - norm = self.det_lognorm(det, start_index, end_index) - # we always filter the entire segment starting from kmin, since the - # gated series may have high frequency components - slc = slice(self._kmin[det], self._kmax[det]) - # get the gated values - gated_hp, gated_hc = gated_wfs[det] - gated_d = gated_data[det] - # we'll overwhiten the ungated data and waveforms for computing - # inner products - d = self._overwhitened_data[det] - # overwhiten the hp and hc - invpsd = self._invpsds[det] - hp = hp*invpsd - hc = hc*invpsd - # get the various gated inner products - hpd = hp[slc].inner(gated_d[slc]).real # - hcd = hc[slc].inner(gated_d[slc]).real # - dhp = d[slc].inner(gated_hp[slc]).real # - dhc = d[slc].inner(gated_hc[slc]).real # - hphp = hp[slc].inner(gated_hp[slc]).real # - hchc = hc[slc].inner(gated_hc[slc]).real # - hphc = hp[slc].inner(gated_hc[slc]).real # - hchp = hc[slc].inner(gated_hp[slc]).real # - dd = d[slc].inner(gated_d[slc]).real # - # since the antenna patterns are real, - # /2 + /2 = fp*(/2 + /2) - # + fc*(/2 + /2) - hd = fp*(hpd + dhp) + fc*(hcd + dhc) - # /2 = /2 - # = fp*fp*/2 + fc*fc*/2 - # + fp*fc*/2 + fc*fp*/2 - hh = fp*fp*hphp + fc*fc*hchc + fp*fc*(hphc + hchp) - # sum up; note that the factor is 2df instead of 4df to account - # for the factor of 1/2 - loglr += norm + 2*invpsd.delta_f*(hd - hh) - lognl += -2 * invpsd.delta_f * dd - # store the maxl polarization - idx = loglr.argmax() - setattr(self._current_stats, 'maxl_polarization', self.pol[idx]) - setattr(self._current_stats, 'maxl_logl', loglr[idx] + lognl) - # compute the marginalized log likelihood - marglogl = special.logsumexp(loglr) + lognl - numpy.log(len(self.pol)) - return float(marglogl) - - @property - def multi_signal_support(self): - """ The list of classes that this model supports in a multi-signal - likelihood - """ - return [type(self)] - - def multi_loglikelihood(self, models): - """ Calculate a multi-model (signal) likelihood - """ - # Generate the waveforms for each submodel - wfs = [] - for m in models + [self]: - # temp fix for wfs in combined run - # set static params 'zero_before_gate' or 'zero_after_gate' to specify portion of wf to zero out - wf = m.get_waveforms() - # if params don't exist, set them to False - # if they do exist, their value doesn't matter; strings always return True - try: - m.current_params['zero_before_gate'] - except KeyError: - m.current_params['zero_before_gate'] = False - try: - m.current_params['zero_after_gate'] - except KeyError: - m.current_params['zero_after_gate'] = False - - if m.current_params['zero_before_gate'] or m.current_params['zero_after_gate']: - gate_times = m.get_gate_times() - for d in wf: - tsp, tsc = wf[d] - start = gate_times[d][0] - # plus and cross pols probably have the same gpsidx, but get both just to be safe - gpsidxp = 
(float(start) - float(tsp.start_time))//tsp.delta_t - gpsidxc = (float(start) - float(tsc.start_time))//tsc.delta_t - tsp = tsp.to_timeseries() - tsc = tsc.to_timeseries() - # zero out inspiral wf after gate - if m.current_params['zero_after_gate']: - tsp[int(gpsidxp):] *= 0 - tsc[int(gpsidxc):] *= 0 - # zero out ringdown wf before gate - if m.current_params['zero_before_gate']: - tsp[:int(gpsidxp)] *= 0 - tsc[:int(gpsidxc)] *= 0 - tsp = tsp.to_frequencyseries() - tsc = tsc.to_frequencyseries() - wf[d] = (tsp, tsc) - wfs.append(wf) - - # combine into a single waveform - combine = {} - for det in self.data: - # get max waveform length - mlenp = max([len(x[det][0]) for x in wfs]) - mlenc = max([len(x[det][1]) for x in wfs]) - mlen = max([mlenp, mlenc]) - # resize plus and cross - [x[det][0].resize(mlen) for x in wfs] - [x[det][1].resize(mlen) for x in wfs] - # combine waveforms - combine[det] = (sum([x[det][0] for x in wfs]), sum([x[det][1] for x in wfs])) - - self._current_wfs = combine - return self._loglikelihood() diff --git a/pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-mod-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-mod-checkpoint.py deleted file mode 100644 index 91710f2c8df..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/gated_gaussian_noise-mod-checkpoint.py +++ /dev/null @@ -1,911 +0,0 @@ -# Copyright (C) 2020 Collin Capano and Shilpa Kastha -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - -"""This module provides model classes that assume the noise is Gaussian and -introduces a gate to remove given times from the data, using the inpainting -method to fill the removed part such that it does not enter the likelihood. -""" - -from abc import abstractmethod -import logging -import numpy -import scipy -from scipy import special - -from pycbc.waveform import (NoWaveformError, FailedWaveformError) -from pycbc.types import FrequencySeries -from pycbc.detector import Detector -from pycbc.pnutils import hybrid_meco_frequency -from pycbc.psd import interpolate -from pycbc import types -from pycbc.waveform.utils import time_from_frequencyseries -from pycbc.waveform import generator -from pycbc.filter import highpass -from .gaussian_noise import (BaseGaussianNoise, create_waveform_generator) -from .base_data import BaseDataModel -from .data_utils import fd_data_from_strain_dict - - -class BaseGatedGaussian(BaseGaussianNoise): - r"""Base model for gated gaussian. - - Provides additional routines for applying a time-domain gate to data. - See :py:class:`GatedGaussianNoise` for more details. 
- """ - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, highpass_waveforms=False, **kwargs): - # we'll want the time-domain data, so store that - self._td_data = {} - # cache the current projection for debugging - self.current_proj = {} - self.current_nproj = {} - # cache the overwhitened data - self._overwhitened_data = {} - # cache the current gated data - self._gated_data = {} - # cache covariance matrix and normalization terms - self._Rss = {} - self._cov = {} - self._lognorm = {} - # cache samples and linear regression for determinant extrapolation - self._cov_samples = {} - self._cov_regressions = {} - # highpass waveforms with the given frequency - self.highpass_waveforms = highpass_waveforms - if self.highpass_waveforms: - logging.info("Will highpass waveforms at %f Hz", - highpass_waveforms) - # set up the boiler-plate attributes - super().__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - - @classmethod - def from_config(cls, cp, data_section='data', data=None, psds=None, - **kwargs): - """Adds highpass filtering to keyword arguments based on config file. - """ - if cp.has_option(data_section, 'strain-high-pass') and \ - 'highpass_waveforms' not in kwargs: - kwargs['highpass_waveforms'] = float(cp.get(data_section, - 'strain-high-pass')) - return super().from_config(cp, data_section=data_section, - data=data, psds=psds, - **kwargs) - - @BaseDataModel.data.setter - def data(self, data): - """Store a copy of the FD and TD data.""" - BaseDataModel.data.fset(self, data) - # store the td version - self._td_data = {det: d.to_timeseries() for det, d in data.items()} - - @property - def td_data(self): - """The data in the time domain.""" - return self._td_data - - @BaseGaussianNoise.psds.setter - def psds(self, psds): - """Sets the psds, and calculates the weight and norm from them. - The data and the low and high frequency cutoffs must be set first. - """ - # check that the data has been set - if self._data is None: - raise ValueError("No data set") - if self._f_lower is None: - raise ValueError("low frequency cutoff not set") - if self._f_upper is None: - raise ValueError("high frequency cutoff not set") - # make sure the relevant caches are cleared - self._psds.clear() - self._invpsds.clear() - self._gated_data.clear() - # store the psds - for det, d in self._data.items(): - if psds is None: - # No psd means assume white PSD - p = FrequencySeries(numpy.ones(int(self._N/2+1)), - delta_f=d.delta_f) - else: - # copy for storage - p = psds[det].copy() - self._psds[det] = p - # we'll store the weight to apply to the inner product - invp = 1./p - self._invpsds[det] = invp - # store the autocorrelation function and covariance matrix for each detector - Rss = p.astype(types.complex_same_precision_as(p)).to_timeseries() - cov = scipy.linalg.toeplitz(Rss/2) # full covariance matrix - self._Rss[det] = Rss - self._cov[det] = cov - # calculate and store the linear regressions to extrapolate determinant values - if self.normalize: - samples, fit = self.logdet_fit(cov, p) - self._cov_samples[det] = samples - self._cov_regressions[det] = fit - self._overwhitened_data = self.whiten(self.data, 2, inplace=False) - - def logdet_fit(self, cov, p): - """Construct a linear regression from a sample of truncated covariance matrices. 
- """ - # initialize lists for matrix sizes and determinants - sample_sizes = [] - sample_dets = [] - # set sizes of sample matrices; ensure exact calculations are only on small matrices - s = cov.shape[0] - max_size = 8192 - if s > max_size: - sample_sizes = [s, max_size, max_size//2, max_size//4] - else: - sample_sizes = [s, s//2, s//4, s//8] - for i in sample_sizes: - # calculate logdet of the full matrix using circulant eigenvalue approximation - if i == s: - ld = 2*numpy.log(p/(2*p.delta_t)).sum() - sample_dets.append(ld) - # generate three more sample matrices using exact calculations - else: - gate_size = s - i - start = (s - gate_size)//2 - end = start + gate_size - tc = numpy.delete(numpy.delete(cov, slice(start, end), 0), slice(start, end), 1) - ld = numpy.linalg.slogdet(tc)[1] - sample_dets.append(ld) - # generate a linear regression using the four points (size, logdet) - x = numpy.vstack([sample_sizes, numpy.ones(len(sample_sizes))]).T - m, b = numpy.linalg.lstsq(x, sample_dets, rcond=None)[0] - return (sample_sizes, sample_dets), (m, b) - - def gate_indices(self, det): - """Calculate the indices corresponding to start and end of gate. - """ - # get time series start and delta_t - ts = self._Rss[det] - start_time = self.td_data[det].start_time - delta_t = ts.delta_t - # get gate start and length from get_gate_times - gate_start, gate_length = self.get_gate_times()[det] - # convert to indices - lindex = int(float(gate_start - start_time)/delta_t) - rindex = lindex + int(gate_length / delta_t) - lindex = lindex if lindex >= 0 else 0 - rindex = rindex if rindex <= len(ts) else len(ts) - return lindex, rindex - - ### trying out exact calculations for static gate sizes ### - - def det_lognorm_linext(self, det, start_index=None, end_index=None): - """Calculate the normalization term from the truncated covariance matrix. - Determinant is estimated using a linear fit to logdet vs truncated matrix size - """ - if not self.normalize: - return 0 - try: - # check if the key already exists; if so, return its value - lognorm = self._lognorm[(det, start_index, end_index)] - except KeyError: - # if not, extrapolate the normalization term - start_index, end_index = self.gate_indices(det) - # get the size of the matrix - cov = self._cov[det] - n = cov.shape[0] - trunc_size = n - (end_index - start_index) - # call the linear regression - m, b = self._cov_regressions[det] - # extrapolate from linear fit - ld = m*trunc_size + b - lognorm = -0.5*(numpy.log(2*numpy.pi)*trunc_size + ld) # full normalization term - # cache the result - self._lognorm[(det, start_index, end_index)] = lognorm - return lognorm - - def det_lognorm(self, det, start_index=None, end_index=None): - """Calculate the normalization term from the truncated covariance matrix. 
- Determinant is calculated exactly using LU factorization - """ - if not self.normalize: - return 0 - # call the cached value if possible - cov = self._cov[det] - try: - full = cov.shape[0] - if start_index == None or end_index == None: - trunc_size = full - else: - trunc_size = full - (end_index - start_index) - lognorm = self._lognorm[(det, trunc_size)] - # if not, do the full calculation - except KeyError: - start, end = self.gate_indices(det) - # truncate the matrix and calculate the norm term - trunc = numpy.delete(numpy.delete(cov, slice(start, end), 0), slice(start, end), 1) - ld = numpy.linalg.slogdet(trunc)[1] - lognorm = -0.5*(numpy.log(2*numpy.pi)*trunc_size + ld) # full normalization term - # cache the result - self._lognorm[(det, trunc_size)] = lognorm - return lognorm - - @property - def normalize(self): - """Determines if the loglikelihood includes the normalization term. - """ - return self._normalize - - ### This is called before psds, so doing direct calls to self._cov, self._psds, etc. throws an error for now ### - - @normalize.setter - def normalize(self, normalize): - """Clears the current stats if the normalization state is changed. - """ - self._normalize = normalize - # set covariance det linear regression iff normalize is set to true and the respective dicts are empty - # if self.normalize and self._cov_samples == {} and self._cov_regressions == {}: - # for det, d in self._data.items(): - # cov = self._cov[det] - # samples, fit = self.logdet_fit(cov, p) - # self._cov_samples[det] = samples - # self._cov_regressions[det] = fit - - @staticmethod - def _nowaveform_logl(): - """Convenience function to set logl values if no waveform generated. - """ - return -numpy.inf - - def _loglr(self): - r"""Computes the log likelihood ratio. - Returns - ------- - float - The value of the log likelihood ratio evaluated at the given point. - """ - return self._loglikelihood() - self._lognl() - - def whiten(self, data, whiten, inplace=False): - """Whitens the given data. - - Parameters - ---------- - data : dict - Dictionary of detector names -> FrequencySeries. - whiten : {0, 1, 2} - Integer indicating what level of whitening to apply. Levels are: - 0: no whitening; 1: whiten; 2: overwhiten. - inplace : bool, optional - If True, modify the data in place. Otherwise, a copy will be - created for whitening. - - - Returns - ------- - dict : - Dictionary of FrequencySeries after the requested whitening has - been applied. - """ - if not inplace: - data = {det: d.copy() for det, d in data.items()} - if whiten: - for det, dtilde in data.items(): - invpsd = self._invpsds[det] - if whiten == 1: - dtilde *= invpsd**0.5 - elif whiten == 2: - dtilde *= invpsd - else: - raise ValueError("whiten must be either 0, 1, or 2") - return data - - def get_waveforms(self): - """The waveforms generated using the current parameters. - - If the waveforms haven't been generated yet, they will be generated, - resized to the same length as the data, and cached. If the - ``highpass_waveforms`` attribute is set, a highpass filter will - also be applied to the waveforms. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. 
- """ - if self._current_wfs is None: - params = self.current_params - wfs = self.waveform_generator.generate(**params) - for det, h in wfs.items(): - # make the same length as the data - h.resize(len(self.data[det])) - # apply high pass - if self.highpass_waveforms: - h = highpass( - h.to_timeseries(), - frequency=self.highpass_waveforms).to_frequencyseries() - wfs[det] = h - self._current_wfs = wfs - return self._current_wfs - - @abstractmethod - def get_gated_waveforms(self): - """Generates and gates waveforms using the current parameters. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - pass - - def get_residuals(self): - """Generates the residuals ``d-h`` using the current parameters. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - wfs = self.get_waveforms() - out = {} - for det, h in wfs.items(): - d = self.data[det] - out[det] = d - h - return out - - def get_data(self): - """Return a copy of the data. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - return {det: d.copy() for det, d in self.data.items()} - - def get_gated_data(self): - """Return a copy of the gated data. - - The gated data will be cached for faster retrieval. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - gate_times = self.get_gate_times() - out = {} - for det, d in self.td_data.items(): - # make sure the cache at least has the detectors in it - try: - cache = self._gated_data[det] - except KeyError: - cache = self._gated_data[det] = {} - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - try: - dtilde = cache[gatestartdelay, dgatedelay] - except KeyError: - # doesn't exist yet, or the gate times changed - cache.clear() - d = d.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - dtilde = d.to_frequencyseries() - # save for next time - cache[gatestartdelay, dgatedelay] = dtilde - out[det] = dtilde - return out - - def get_gate_times(self): - """Gets the time to apply a gate based on the current sky position. - - If the parameter ``gatefunc`` is set to ``'hmeco'``, the gate times - will be calculated based on the hybrid MECO of the given set of - parameters; see ``get_gate_times_hmeco`` for details. Otherwise, the - gate times will just be retrieved from the ``t_gate_start`` and - ``t_gate_end`` parameters. - - .. warning:: - Since the normalization of the likelihood is currently not - being calculated, it is recommended that you do not use - ``gatefunc``, instead using fixed gate times. 
- - Returns - ------- - dict : - Dictionary of detector names -> (gate start, gate width) - """ - params = self.current_params - try: - gatefunc = self.current_params['gatefunc'] - except KeyError: - gatefunc = None - if gatefunc == 'hmeco': - return self.get_gate_times_hmeco() - # gate input for ringdown analysis which consideres a start time - # and an end time - gatestart = params['t_gate_start'] - gateend = params['t_gate_end'] - # we'll need the sky location for determining time shifts - ra = self.current_params['ra'] - dec = self.current_params['dec'] - gatetimes = {} - for det in self._invpsds: - thisdet = Detector(det) - # account for the time delay between the waveforms of the - # different detectors - gatestartdelay = gatestart + thisdet.time_delay_from_earth_center( - ra, dec, gatestart) - gateenddelay = gateend + thisdet.time_delay_from_earth_center( - ra, dec, gateend) - dgatedelay = gateenddelay - gatestartdelay - gatetimes[det] = (gatestartdelay, dgatedelay) - return gatetimes - - def get_gate_times_hmeco(self): - """Gets the time to apply a gate based on the current sky position. - Returns - ------- - dict : - Dictionary of detector names -> (gate start, gate width) - """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get waveform parameters - params = self.current_params - spin1 = params['spin1z'] - spin2 = params['spin2z'] - # gate input for ringdown analysis which consideres a start time - # and an end time - dgate = params['gate_window'] - meco_f = hybrid_meco_frequency(params['mass1'], params['mass2'], - spin1, spin2) - # figure out the gate times - gatetimes = {} - for det, h in wfs.items(): - invpsd = self._invpsds[det] - h.resize(len(invpsd)) - ht = h.to_timeseries() - f_low = int((self._f_lower[det]+1)/h.delta_f) - sample_freqs = h.sample_frequencies[f_low:].numpy() - f_idx = numpy.where(sample_freqs <= meco_f)[0][-1] - # find time corresponding to meco frequency - t_from_freq = time_from_frequencyseries( - h[f_low:], sample_frequencies=sample_freqs) - if t_from_freq[f_idx] > 0: - gatestartdelay = t_from_freq[f_idx] + float(t_from_freq.epoch) - else: - gatestartdelay = t_from_freq[f_idx] + ht.sample_times[-1] - gatestartdelay = min(gatestartdelay, params['t_gate_start']) - gatetimes[det] = (gatestartdelay, dgate) - return gatetimes - - def _lognl(self): - """Calculates the log of the noise likelihood. - """ - # clear variables - lognl = 0. 
- self._det_lognls.clear() - # get the times of the gates - gate_times = self.get_gate_times() - self.current_nproj.clear() - for det, invpsd in self._invpsds.items(): - start_index, end_index = self.gate_indices(det) - norm = self.det_lognorm_linext(det, start_index, end_index) # linear estimation - #norm = self.det_lognorm(det, start_index, end_index) # exact calculation - gatestartdelay, dgatedelay = gate_times[det] - # we always filter the entire segment starting from kmin, since the - # gated series may have high frequency components - slc = slice(self._kmin[det], self._kmax[det]) - # gate the data - data = self.td_data[det] - gated_dt = data.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - self.current_nproj[det] = (gated_dt.proj, gated_dt.projslc) - # convert to the frequency series - gated_d = gated_dt.to_frequencyseries() - # overwhiten - gated_d *= invpsd - d = self.data[det] - # inner product - ip = 4 * invpsd.delta_f * d[slc].inner(gated_d[slc]).real # - dd = norm - 0.5*ip - # store - self._det_lognls[det] = dd - lognl += dd - return float(lognl) - - def det_lognl(self, det): - # make sure lognl has been called - _ = self._trytoget('lognl', self._lognl) - # the det_lognls dict should have been updated, so can call it now - return self._det_lognls[det] - - @staticmethod - def _fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict): - """Wrapper around :py:func:`data_utils.fd_data_from_strain_dict`. - - Ensures that if the PSD is estimated from data, the inverse spectrum - truncation uses a Hann window, and that the low frequency cutoff is - zero. - """ - if opts.psd_inverse_length and opts.invpsd_trunc_method is None: - # make sure invpsd truncation is set to hanning - logging.info("Using Hann window to truncate inverse PSD") - opts.invpsd_trunc_method = 'hann' - lfs = None - if opts.psd_estimation: - # make sure low frequency cutoff is zero - logging.info("Setting low frequency cutoff of PSD to 0") - lfs = opts.low_frequency_cutoff.copy() - opts.low_frequency_cutoff = {d: 0. for d in lfs} - out = fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict) - # set back - if lfs is not None: - opts.low_frequency_cutoff = lfs - return out - - def write_metadata(self, fp, group=None): - """Adds writing the psds, and analyzed detectors. - - The analyzed detectors, their analysis segments, and the segments - used for psd estimation are written as - ``analyzed_detectors``, ``{{detector}}_analysis_segment``, and - ``{{detector}}_psd_segment``, respectively. These are either written - to the specified ``group``'s attrs, or to the top level attrs if - ``group`` is None. - - Parameters - ---------- - fp : pycbc.inference.io.BaseInferenceFile instance - The inference file to write to. - group : str, optional - If provided, the metadata will be written to the attrs specified - by group, i.e., to ``fp[group].attrs``. Otherwise, metadata is - written to the top-level attrs (``fp.attrs``). 
- """ - BaseDataModel.write_metadata(self, fp) - attrs = fp.getattrs(group=group) - # write the analyzed detectors and times - attrs['analyzed_detectors'] = self.detectors # store fitting values here - for det, data in self.data.items(): - key = '{}_analysis_segment'.format(det) - attrs[key] = [float(data.start_time), float(data.end_time)] - # store covariance determinant extrapolation information (checkpoint) - if self.normalize: - attrs['{}_cov_sample'.format(det)] = self._cov_samples[det] - attrs['{}_cov_regression'.format(det)] = self._cov_regressions[det] - if self._psds is not None and not self.no_save_data: - fp.write_psd(self._psds, group=group) - # write the times used for psd estimation (if they were provided) - for det in self.psd_segments: - key = '{}_psd_segment'.format(det) - attrs[key] = list(map(float, self.psd_segments[det])) - # save the frequency cutoffs - for det in self.detectors: - attrs['{}_likelihood_low_freq'.format(det)] = self._f_lower[det] - if self._f_upper[det] is not None: - attrs['{}_likelihood_high_freq'.format(det)] = \ - self._f_upper[det] - - -class GatedGaussianNoise(BaseGatedGaussian): - r"""Model that applies a time domain gate, assuming stationary Gaussian - noise. - - The gate start and end times are set by providing ``t_gate_start`` and - ``t_gate_end`` parameters, respectively. This will cause the gated times - to be excised from the analysis. For more details on the likelihood - function and its derivation, see - `arXiv:2105.05238 `_. - - .. warning:: - The normalization of the likelihood depends on the gate times. However, - at the moment, the normalization is not calculated, as it depends on - the determinant of the truncated covariance matrix (see Eq. 4 of - arXiv:2105.05238). For this reason it is recommended that you only - use this model for fixed gate times. - - """ - name = 'gated_gaussian_noise' - - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, **kwargs): - # set up the boiler-plate attributes - super().__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - # create the waveform generator - self.waveform_generator = create_waveform_generator( - self.variable_params, self.data, - waveform_transforms=self.waveform_transforms, - recalibration=self.recalibration, - gates=self.gates, **self.static_params) - - @property - def _extra_stats(self): - """No extra stats are stored.""" - return [] - - ### This needs to be called after running model.update() for changes to take effect ### - - def _loglikelihood(self): - r"""Computes the log likelihood after removing the power within the - given time window, - - .. math:: - \log p(d|\Theta) = -\frac{1}{2} \sum_i - \left< d_i - h_i(\Theta) | d_i - h_i(\Theta) \right>, - - at the current parameter values :math:`\Theta`. - - Returns - ------- - float - The value of the log likelihood. - """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get the times of the gates - gate_times = self.get_gate_times() - logl = 0. 
- self.current_proj.clear() - for det, h in wfs.items(): - invpsd = self._invpsds[det] - start_index, end_index = self.gate_indices(det) - #norm = self.det_lognorm(det, start_index, end_index) - norm = self.det_lognorm_linext(det, start_index, end_index) - gatestartdelay, dgatedelay = gate_times[det] - # we always filter the entire segment starting from kmin, since the - # gated series may have high frequency components - slc = slice(self._kmin[det], self._kmax[det]) - # calculate the residual - data = self.td_data[det] - ht = h.to_timeseries() - res = data - ht - rtilde = res.to_frequencyseries() - gated_res = res.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - self.current_proj[det] = (gated_res.proj, gated_res.projslc) - gated_rtilde = gated_res.to_frequencyseries() - # overwhiten - gated_rtilde *= invpsd - rr = 4 * invpsd.delta_f * rtilde[slc].inner(gated_rtilde[slc]).real - logl += norm - 0.5*rr - return float(logl) - - def get_gated_waveforms(self): - wfs = self.get_waveforms() - gate_times = self.get_gate_times() - out = {} - for det, h in wfs.items(): - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - ht = h.to_timeseries() - ht = ht.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=False, - invpsd=invpsd, method='paint') - h = ht.to_frequencyseries() - out[det] = h - return out - - def get_gated_residuals(self): - """Generates the gated residuals ``d-h`` using the current parameters. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. - """ - params = self.current_params - wfs = self.waveform_generator.generate(**params) - gate_times = self.get_gate_times() - out = {} - for det, h in wfs.items(): - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - data = self.td_data[det] - ht = h.to_timeseries() - res = data - ht - res = res.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=True, - invpsd=invpsd, method='paint') - res = res.to_frequencyseries() - out[det] = res - return out - - -class GatedGaussianMargPol(BaseGatedGaussian): - r"""Gated gaussian model with numerical marginalization over polarization. - - This implements the GatedGaussian likelihood with an explicit numerical - marginalization over polarization angle. This is accomplished using - a fixed set of integration points distribution uniformation between - 0 and 2pi. By default, 1000 integration points are used. - The 'polarization_samples' argument can be passed to set an alternate - number of integration points. 
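(Illustrative sketch, not part of the original file: the marginalization described above is a discrete average of the likelihood over the uniform polarization grid, carried out in log space exactly as in the marglogl line of _loglikelihood. The helper and the dependence of loglr on polarization below are made up.)

    import numpy as np
    from scipy import special

    def marginalize_over_polarization(loglr, lognl):
        # log of the mean of exp(loglr) over the grid, plus the noise term:
        # logsumexp(loglr) - log(N) + lognl, as in the marglogl line above
        return special.logsumexp(loglr) + lognl - np.log(len(loglr))

    # toy example: 1000 uniformly spaced polarization samples
    pol = np.linspace(0, 2 * np.pi, 1000)
    loglr = 5.0 + np.cos(2 * pol)        # made-up polarization dependence
    print(marginalize_over_polarization(loglr, lognl=-100.0))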
- """ - name = 'gated_gaussian_margpol' - - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, - polarization_samples=1000, **kwargs): - # set up the boiler-plate attributes - super().__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - # the polarization parameters - self.polarization_samples = polarization_samples - self.pol = numpy.linspace(0, 2*numpy.pi, self.polarization_samples) - self.dets = {} - # create the waveform generator - self.waveform_generator = create_waveform_generator( - self.variable_params, self.data, - waveform_transforms=self.waveform_transforms, - recalibration=self.recalibration, - generator_class=generator.FDomainDetFrameTwoPolGenerator, - **self.static_params) - - def get_waveforms(self): - if self._current_wfs is None: - params = self.current_params - wfs = self.waveform_generator.generate(**params) - for det, (hp, hc) in wfs.items(): - # make the same length as the data - hp.resize(len(self.data[det])) - hc.resize(len(self.data[det])) - # apply high pass - if self.highpass_waveforms: - hp = highpass( - hp.to_timeseries(), - frequency=self.highpass_waveforms).to_frequencyseries() - hc = highpass( - hc.to_timeseries(), - frequency=self.highpass_waveforms).to_frequencyseries() - wfs[det] = (hp, hc) - self._current_wfs = wfs - return self._current_wfs - - def get_gated_waveforms(self): - wfs = self.get_waveforms() - gate_times = self.get_gate_times() - out = {} - for det in wfs: - invpsd = self._invpsds[det] - gatestartdelay, dgatedelay = gate_times[det] - # the waveforms are a dictionary of (hp, hc) - pols = [] - for h in wfs[det]: - ht = h.to_timeseries() - ht = ht.gate(gatestartdelay + dgatedelay/2, - window=dgatedelay/2, copy=False, - invpsd=invpsd, method='paint') - h = ht.to_frequencyseries() - pols.append(h) - out[det] = tuple(pols) - return out - - @property - def _extra_stats(self): - """Adds the maxL polarization and corresponding likelihood.""" - return ['maxl_polarization', 'maxl_logl'] - - def _loglikelihood(self): - r"""Computes the log likelihood after removing the power within the - given time window, - - .. math:: - \log p(d|\Theta) = -\frac{1}{2} \sum_i - \left< d_i - h_i(\Theta) | d_i - h_i(\Theta) \right>, - - at the current parameter values :math:`\Theta`. - - Returns - ------- - float - The value of the log likelihood. - """ - # generate the template waveform - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_logl() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_logl() - raise e - # get the gated waveforms and data - gated_wfs = self.get_gated_waveforms() - gated_data = self.get_gated_data() - # cycle over - loglr = 0. - lognl = 0. 
-        for det, (hp, hc) in wfs.items():
-            # get the antenna patterns
-            if det not in self.dets:
-                self.dets[det] = Detector(det)
-            fp, fc = self.dets[det].antenna_pattern(self.current_params['ra'],
-                                                    self.current_params['dec'],
-                                                    self.pol,
-                                                    self.current_params['tc'])
-            start_index, end_index = self.gate_indices(det)
-            # norm = self.det_lognorm(det, start_index, end_index)
-            norm = self.det_lognorm_linext(det, start_index, end_index)
-            # we always filter the entire segment starting from kmin, since the
-            # gated series may have high frequency components
-            slc = slice(self._kmin[det], self._kmax[det])
-            # get the gated values
-            gated_hp, gated_hc = gated_wfs[det]
-            gated_d = gated_data[det]
-            # we'll overwhiten the ungated data and waveforms for computing
-            # inner products
-            d = self._overwhitened_data[det]
-            # overwhiten the hp and hc
-            # we'll do this in place for computational efficiency, but as a
-            # result we'll clear the current waveforms cache so a repeated call
-            # to get_waveforms does not return the overwhitened versions
-            self._current_wfs = None
-            invpsd = self._invpsds[det]
-            hp *= invpsd
-            hc *= invpsd
-            # get the various gated inner products
-            hpd = hp[slc].inner(gated_d[slc]).real  # <hp, d>
-            hcd = hc[slc].inner(gated_d[slc]).real  # <hc, d>
-            dhp = d[slc].inner(gated_hp[slc]).real  # <d, hp>
-            dhc = d[slc].inner(gated_hc[slc]).real  # <d, hc>
-            hphp = hp[slc].inner(gated_hp[slc]).real  # <hp, hp>
-            hchc = hc[slc].inner(gated_hc[slc]).real  # <hc, hc>
-            hphc = hp[slc].inner(gated_hc[slc]).real  # <hp, hc>
-            hchp = hc[slc].inner(gated_hp[slc]).real  # <hc, hp>
-            dd = d[slc].inner(gated_d[slc]).real  # <d, d>
-            # since the antenna patterns are real,
-            # <h, d>/2 + <d, h>/2 = fp*(<hp, d>/2 + <d, hp>/2)
-            #                       + fc*(<hc, d>/2 + <d, hc>/2)
-            hd = fp*(hpd + dhp) + fc*(hcd + dhc)
-            # <h, h>/2 = <h, h>/2
-            #          = fp*fp*<hp, hp>/2 + fc*fc*<hc, hc>/2
-            #            + fp*fc*<hp, hc>/2 + fc*fp*<hc, hp>/2
-            hh = fp*fp*hphp + fc*fc*hchc + fp*fc*(hphc + hchp)
-            # sum up; note that the factor is 2df instead of 4df to account
-            # for the factor of 1/2
-            loglr += norm + 2*invpsd.delta_f*(hd - hh)
-            lognl += -2 * invpsd.delta_f * dd
-        # store the maxl polarization
-        idx = loglr.argmax()
-        setattr(self._current_stats, 'maxl_polarization', self.pol[idx])
-        setattr(self._current_stats, 'maxl_logl', loglr[idx] + lognl)
-        # compute the marginalized log likelihood
-        marglogl = special.logsumexp(loglr) + lognl - numpy.log(len(self.pol))
-        return float(marglogl)
diff --git a/pycbc/inference/models/.ipynb_checkpoints/gaussian_noise-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/gaussian_noise-checkpoint.py
deleted file mode 100644
index 9b018a85cdb..00000000000
--- a/pycbc/inference/models/.ipynb_checkpoints/gaussian_noise-checkpoint.py
+++ /dev/null
@@ -1,1202 +0,0 @@
-# Copyright (C) 2018 Collin Capano
-# This program is free software; you can redistribute it and/or modify it
-# under the terms of the GNU General Public License as published by the
-# Free Software Foundation; either version 3 of the License, or (at your
-# option) any later version.
-#
-# This program is distributed in the hope that it will be useful, but
-# WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
-# Public License for more details.
-#
-# You should have received a copy of the GNU General Public License along
-# with this program; if not, write to the Free Software Foundation, Inc.,
-# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
-
-"""This module provides model classes that assume the noise is Gaussian.
-""" - -import logging -import shlex -from abc import ABCMeta -import numpy - -from pycbc import filter as pyfilter -from pycbc.waveform import (NoWaveformError, FailedWaveformError) -from pycbc.waveform import generator -from pycbc.types import FrequencySeries -from pycbc.strain import gates_from_cli -from pycbc.strain.calibration import Recalibrate -from pycbc.inject import InjectionSet -from pycbc.io import FieldArray -from pycbc.types.optparse import MultiDetOptionAction - -from .base import ModelStats -from .base_data import BaseDataModel -from .data_utils import (data_opts_from_config, data_from_cli, - fd_data_from_strain_dict, gate_overwhitened_data) - - -class BaseGaussianNoise(BaseDataModel, metaclass=ABCMeta): - r"""Model for analyzing GW data with assuming a wide-sense stationary - Gaussian noise model. - - This model will load gravitational wave data and calculate the log noise - likelihood ``_lognl`` and normalization. It also implements the - ``_loglikelihood`` function as the sum of the log likelihood ratio and the - ``lognl``. It does not implement a log likelihood ratio function - ``_loglr``, however, since that can differ depending on the signal model. - Models that analyze GW data assuming it is stationary Gaussian should - therefore inherit from this class and implement their own ``_loglr`` - function. - - For more details on the inner product used, the log likelihood of the - noise, and the normalization factor, see :py:class:`GaussianNoise`. - - Parameters - ---------- - variable_params : (tuple of) string(s) - A tuple of parameter names that will be varied. - data : dict - A dictionary of data, in which the keys are the detector names and the - values are the data (assumed to be unwhitened). All data must have the - same frequency resolution. - low_frequency_cutoff : dict - A dictionary of starting frequencies, in which the keys are the - detector names and the values are the starting frequencies for the - respective detectors to be used for computing inner products. - psds : dict, optional - A dictionary of FrequencySeries keyed by the detector names. The - dictionary must have a psd for each detector specified in the data - dictionary. If provided, the inner products in each detector will be - weighted by 1/psd of that detector. - high_frequency_cutoff : dict, optional - A dictionary of ending frequencies, in which the keys are the - detector names and the values are the ending frequencies for the - respective detectors to be used for computing inner products. If not - provided, the minimum of the largest frequency stored in the data - and a given waveform will be used. - normalize : bool, optional - If True, the normalization factor :math:`alpha` will be included in the - log likelihood. See :py:class:`GaussianNoise` for details. Default is - to not include it. - static_params : dict, optional - A dictionary of parameter names -> values to keep fixed. - ignore_failed_waveforms : bool, optional - If the waveform generator raises an error when it tries to generate, - treat the point as having zero likelihood. This allows the parameter - estimation to continue. Otherwise, an error will be raised, stopping - the run. Default is False. - \**kwargs : - All other keyword arguments are passed to ``BaseDataModel``. - - Attributes - ---------- - ignore_failed_waveforms : bool - If True, points in parameter space that cause waveform generation to - fail (i.e., they raise a ``FailedWaveformError``) will be treated as - points with zero likelihood. 
Otherwise, such points will cause the - model to raise a ``FailedWaveformError``. - """ - - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, ignore_failed_waveforms=False, - no_save_data=False, - **kwargs): - # set up the boiler-plate attributes - super(BaseGaussianNoise, self).__init__(variable_params, data, - static_params=static_params, - no_save_data=no_save_data, - **kwargs) - self.ignore_failed_waveforms = ignore_failed_waveforms - self.no_save_data = no_save_data - # check if low frequency cutoff has been provided for every IFO with - # data - for ifo in self.data: - if low_frequency_cutoff[ifo] is None: - raise ValueError( - "A low-frequency-cutoff must be provided for every " - "detector for which data has been provided. If " - "loading the model settings from " - "a config file, please provide " - "`{DETECTOR}:low-frequency-cutoff` options for " - "every detector in the `[model]` section, where " - "`{DETECTOR} is the name of the detector," - "or provide a single low-frequency-cutoff option" - "which will be used for all detectors") - - # check that the data sets all have the same delta fs and delta ts - dts = numpy.array([d.delta_t for d in self.data.values()]) - dfs = numpy.array([d.delta_f for d in self.data.values()]) - if all(dts == dts[0]) and all(dfs == dfs[0]): - self.all_ifodata_same_rate_length = True - else: - self.all_ifodata_same_rate_length = False - logging.info( - "You are using different data segment lengths or " - "sampling rates for different IFOs") - - # store the number of samples in the time domain - self._N = {} - for (det, d) in self._data.items(): - self._N[det] = int(1./(d.delta_f*d.delta_t)) - - # set lower/upper frequency cutoff - if high_frequency_cutoff is None: - high_frequency_cutoff = {ifo: None for ifo in self.data} - self._f_upper = high_frequency_cutoff - self._f_lower = low_frequency_cutoff - - # Set the cutoff indices - self._kmin = {} - self._kmax = {} - - for (det, d) in self._data.items(): - kmin, kmax = pyfilter.get_cutoff_indices(self._f_lower[det], - self._f_upper[det], - d.delta_f, self._N[det]) - self._kmin[det] = kmin - self._kmax[det] = kmax - - # store the psd segments - self._psd_segments = {} - if psds is not None: - self.set_psd_segments(psds) - - # store the psds and calculate the inner product weight - self._psds = {} - self._invpsds = {} - self._weight = {} - self._lognorm = {} - self._det_lognls = {} - self._whitened_data = {} - - # set the normalization state - self._normalize = False - self.normalize = normalize - # store the psds and whiten the data - self.psds = psds - - # attribute for storing the current waveforms - self._current_wfs = None - - @property - def high_frequency_cutoff(self): - """The high frequency cutoff of the inner product.""" - return self._f_upper - - @property - def low_frequency_cutoff(self): - """The low frequency cutoff of the inner product.""" - return self._f_lower - - @property - def kmin(self): - """Dictionary of starting indices for the inner product. - - This is determined from the lower frequency cutoff and the ``delta_f`` - of the data using - :py:func:`pycbc.filter.matchedfilter.get_cutoff_indices`. - """ - return self._kmin - - @property - def kmax(self): - """Dictionary of ending indices for the inner product. - - This is determined from the high frequency cutoff and the ``delta_f`` - of the data using - :py:func:`pycbc.filter.matchedfilter.get_cutoff_indices`. 
If no high - frequency cutoff was provided, this will be the indice corresponding to - the Nyquist frequency. - """ - return self._kmax - - @property - def psds(self): - """Dictionary of detectors -> PSD frequency series. - - If no PSD was provided for a detector, this will just be a frequency - series of ones. - """ - return self._psds - - @psds.setter - def psds(self, psds): - """Sets the psds, and calculates the weight and norm from them. - - The data and the low and high frequency cutoffs must be set first. - """ - # check that the data has been set - if self._data is None: - raise ValueError("No data set") - if self._f_lower is None: - raise ValueError("low frequency cutoff not set") - if self._f_upper is None: - raise ValueError("high frequency cutoff not set") - # make sure the relevant caches are cleared - self._psds.clear() - self._invpsds.clear() - self._weight.clear() - self._lognorm.clear() - self._det_lognls.clear() - self._whitened_data.clear() - for det, d in self._data.items(): - if psds is None: - # No psd means assume white PSD - p = FrequencySeries(numpy.ones(int(self._N[det]/2+1)), - delta_f=d.delta_f) - else: - # copy for storage - p = psds[det].copy() - self._psds[det] = p - # we'll store the weight to apply to the inner product - # only set weight in band we will analyze - kmin = self._kmin[det] - kmax = self._kmax[det] - invp = FrequencySeries(numpy.zeros(len(p)), delta_f=p.delta_f) - invp[kmin:kmax] = 1./p[kmin:kmax] - self._invpsds[det] = invp - self._weight[det] = numpy.sqrt(4 * invp.delta_f * invp) - self._whitened_data[det] = d.copy() - self._whitened_data[det] *= self._weight[det] - # set the lognl and lognorm; we'll get this by just calling lognl - _ = self.lognl - - @property - def psd_segments(self): - """Dictionary giving times used for PSD estimation for each detector. - - If a detector's PSD was not estimated from data, or the segment wasn't - provided, that detector will not be in the dictionary. - """ - return self._psd_segments - - def set_psd_segments(self, psds): - """Sets the PSD segments from a dictionary of PSDs. - - This attempts to get the PSD segment from a ``psd_segment`` attribute - of each detector's PSD frequency series. If that attribute isn't set, - then that detector is not added to the dictionary of PSD segments. - - Parameters - ---------- - psds : dict - Dictionary of detector name -> PSD frequency series. The segment - used for each PSD will try to be retrieved from the PSD's - ``.psd_segment`` attribute. - """ - for det, p in psds.items(): - try: - self._psd_segments[det] = p.psd_segment - except AttributeError: - continue - - @property - def weight(self): - r"""Dictionary of detectors -> frequency series of inner-product - weights. - - The weights are :math:`\sqrt{4 \Delta f / S_n(f)}`. This is set when - the PSDs are set. - """ - return self._weight - - @property - def whitened_data(self): - r"""Dictionary of detectors -> whitened data frequency series. - - The whitened data is the data multiplied by the inner-product weight. - Note that this includes the :math:`\sqrt{4 \Delta f}` factor. This - is set when the PSDs are set. - """ - return self._whitened_data - - def det_lognorm(self, det): - """The log of the likelihood normalization in the given detector. - - If ``self.normalize`` is False, will just return 0. - """ - if not self.normalize: - return 0. 
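# Illustrative sketch (toy numbers; not part of the original class): the band
# limits and whitening weight set up above can be reproduced standalone. Only
# get_cutoff_indices and the sqrt(4*delta_f/S_n) convention are taken from the
# code itself.
import numpy
from pycbc.filter import get_cutoff_indices

delta_f, N = 0.25, 16384
kmin, kmax = get_cutoff_indices(20., None, delta_f, N)  # fhigh=None -> Nyquist index
psd = numpy.full(N // 2 + 1, 1e-46)                     # toy one-sided PSD
invpsd = numpy.zeros_like(psd)
invpsd[kmin:kmax] = 1. / psd[kmin:kmax]                 # weight only the analyzed band
weight = numpy.sqrt(4 * delta_f * invpsd)               # multiplying data by this whitens it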
- try: - return self._lognorm[det] - except KeyError: - # hasn't been calculated yet - p = self._psds[det] - dt = self._whitened_data[det].delta_t - kmin = self._kmin[det] - kmax = self._kmax[det] - lognorm = -float(self._N[det]*numpy.log(numpy.pi*self._N[det]*dt)/2. - + numpy.log(p[kmin:kmax]).sum()) - self._lognorm[det] = lognorm - return self._lognorm[det] - - @property - def normalize(self): - """Determines if the loglikelihood includes the normalization term. - """ - return self._normalize - - @normalize.setter - def normalize(self, normalize): - """Clears the current stats if the normalization state is changed. - """ - if normalize != self._normalize: - self._current_stats = ModelStats() - self._lognorm.clear() - self._det_lognls.clear() - self._normalize = normalize - - @property - def lognorm(self): - """The log of the normalization of the log likelihood.""" - return sum(self.det_lognorm(det) for det in self._data) - - def det_lognl(self, det): - r"""Returns the log likelihood of the noise in the given detector: - - .. math:: - - \log p(d_i|n_i) = \log \alpha_i - - \frac{1}{2} \left. - - - Parameters - ---------- - det : str - The name of the detector. - - Returns - ------- - float : - The log likelihood of the noise in the requested detector. - """ - try: - return self._det_lognls[det] - except KeyError: - # hasn't been calculated yet; calculate & store - kmin = self._kmin[det] - kmax = self._kmax[det] - d = self._whitened_data[det] - lognorm = self.det_lognorm(det) - lognl = lognorm - 0.5 * d[kmin:kmax].inner(d[kmin:kmax]).real - self._det_lognls[det] = lognl - return self._det_lognls[det] - - def _lognl(self): - """Computes the log likelihood assuming the data is noise. - - Since this is a constant for Gaussian noise, this is only computed once - then stored. - """ - return sum(self.det_lognl(det) for det in self._data) - - def update(self, **params): - # update - super().update(**params) - # reset current waveforms - self._current_wfs = None - - def _loglikelihood(self): - r"""Computes the log likelihood of the paramaters, - - .. math:: - - \log p(d|\Theta, h) = \log \alpha -\frac{1}{2}\sum_i - \left, - - at the current parameter values :math:`\Theta`. - - Returns - ------- - float - The value of the log likelihood evaluated at the given point. - """ - # since the loglr has fewer terms, we'll call that, then just add - # back the noise term that canceled in the log likelihood ratio - return self.loglr + self.lognl - - def write_metadata(self, fp, group=None): - """Adds writing the psds, analyzed detectors, and lognl. - - The analyzed detectors, their analysis segments, and the segments - used for psd estimation are written as - ``analyzed_detectors``, ``{{detector}}_analysis_segment``, and - ``{{detector}}_psd_segment``, respectively. These are either written - to the specified ``group``'s attrs, or to the top level attrs if - ``group`` is None. - - The total and each detector's lognl is written to the sample group's - ``attrs``. If a group is specified, the group name will be prependend - to the lognl labels with ``{group}__``, with any ``/`` in the group - path replaced with ``__``. For example, if group is ``/a/b``, the - ``lognl`` will be written as ``a__b__lognl`` in the sample's group - attrs. - - Parameters - ---------- - fp : pycbc.inference.io.BaseInferenceFile instance - The inference file to write to. - group : str, optional - If provided, the metadata will be written to the attrs specified - by group, i.e., to ``fp[group].attrs``. 
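For concreteness, the per-detector noise likelihood computed by ``det_lognl`` above is just minus half the norm of the whitened data over the analyzed band, plus the optional normalization term. A toy sketch, with illustrative array names rather than the class attributes:

    import numpy

    rng = numpy.random.default_rng(0)
    kmin, kmax = 80, 8193
    whitened_d = rng.normal(size=8193) + 1j * rng.normal(size=8193)  # toy whitened data
    lognorm = 0.                                                     # the normalize=False case
    # <d, d> over the band; vdot conjugates its first argument
    dd = numpy.vdot(whitened_d[kmin:kmax], whitened_d[kmin:kmax]).real
    lognl = lognorm - 0.5 * dd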
Otherwise, metadata is - written to the top-level attrs (``fp.attrs``). - """ - super().write_metadata(fp, group=group) - attrs = fp.getattrs(group=group) - # write the analyzed detectors and times - attrs['analyzed_detectors'] = self.detectors - for det, data in self.data.items(): - key = '{}_analysis_segment'.format(det) - attrs[key] = [float(data.start_time), float(data.end_time)] - if self._psds is not None and not self.no_save_data: - fp.write_psd(self._psds, group=group) - # write the times used for psd estimation (if they were provided) - for det in self.psd_segments: - key = '{}_psd_segment'.format(det) - attrs[key] = list(map(float, self.psd_segments[det])) - # save the frequency cutoffs - for det in self.detectors: - attrs['{}_likelihood_low_freq'.format(det)] = self._f_lower[det] - if self._f_upper[det] is not None: - attrs['{}_likelihood_high_freq'.format(det)] = \ - self._f_upper[det] - # write the lognl to the samples group attrs - sampattrs = fp.getattrs(group=fp.samples_group) - # if a group is specified, prepend the lognl names with it - if group is None or group == '/': - prefix = '' - else: - prefix = group.replace('/', '__') - if not prefix.endswith('__'): - prefix += '__' - sampattrs['{}lognl'.format(prefix)] = self.lognl - # also save the lognl in each detector - for det in self.detectors: - sampattrs['{}{}_lognl'.format(prefix, det)] = self.det_lognl(det) - - @staticmethod - def _fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict): - """Wrapper around :py:func:`data_utils.fd_data_from_strain_dict`.""" - return fd_data_from_strain_dict(opts, strain_dict, psd_strain_dict) - - @classmethod - def from_config(cls, cp, data_section='data', data=None, psds=None, - **kwargs): - r"""Initializes an instance of this class from the given config file. - - In addition to ``[model]``, a ``data_section`` (default ``[data]``) - must be in the configuration file. The data section specifies settings - for loading data and estimating PSDs. See the `online documentation - `_ for - more details. - - The following options are read from the ``[model]`` section, in - addition to ``name`` (which must be set): - - * ``{{DET}}-low-frequency-cutoff = FLOAT`` : - The low frequency cutoff to use for each detector {{DET}}. A cutoff - must be provided for every detector that may be analyzed (any - additional detectors are ignored). - * ``{{DET}}-high-frequency-cutoff = FLOAT`` : - (Optional) A high frequency cutoff for each detector. If not - provided, the Nyquist frequency is used. - * ``check-for-valid-times =`` : - (Optional) If provided, will check that there are no data quality - flags on during the analysis segment and the segment used for PSD - estimation in each detector. To check for flags, - :py:func:`pycbc.dq.query_flag` is used, with settings pulled from the - ``dq-*`` options in the ``[data]`` section. If a detector has bad - data quality during either the analysis segment or PSD segment, it - will be removed from the analysis. - * ``shift-psd-times-to-valid =`` : - (Optional) If provided, the segment used for PSD estimation will - automatically be shifted left or right until a continous block of - data with no data quality issues can be found. If no block can be - found with a maximum shift of +/- the requested psd segment length, - the detector will not be analyzed. - * ``err-on-missing-detectors =`` : - Raises an error if any detector is removed from the analysis because - a valid time could not be found. 
Otherwise, a warning is printed - to screen and the detector is removed from the analysis. - * ``normalize =`` : - (Optional) Turn on the normalization factor. - * ``ignore-failed-waveforms =`` : - Sets the ``ignore_failed_waveforms`` attribute. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. - data_section : str, optional - The name of the section to load data options from. - \**kwargs : - All additional keyword arguments are passed to the class. Any - provided keyword will override what is in the config file. - """ - # get the injection file, to replace any FROM_INJECTION settings - if 'injection-file' in cp.options('data'): - injection_file = cp.get('data', 'injection-file') - else: - injection_file = None - # update any values that are to be retrieved from the injection - # Note: this does nothing if there are FROM_INJECTION values - get_values_from_injection(cp, injection_file, update_cp=True) - args = cls._init_args_from_config(cp) - # add the injection file - args['injection_file'] = injection_file - # check if normalize is set - if cp.has_option('model', 'normalize'): - args['normalize'] = True - if cp.has_option('model', 'ignore-failed-waveforms'): - args['ignore_failed_waveforms'] = True - if cp.has_option('model', 'no-save-data'): - args['no_save_data'] = True - # get any other keyword arguments provided in the model section - ignore_args = ['name', 'normalize', - 'ignore-failed-waveforms', 'no-save-data'] - for option in cp.options("model"): - if option in ("low-frequency-cutoff", "high-frequency-cutoff"): - ignore_args.append(option) - name = option.replace('-', '_') - args[name] = cp.get_cli_option('model', name, - nargs='+', type=float, - action=MultiDetOptionAction) - - if 'low_frequency_cutoff' not in args: - raise ValueError("low-frequency-cutoff must be provided in the" - " model section, but is not found!") - - # data args - bool_args = ['check-for-valid-times', 'shift-psd-times-to-valid', - 'err-on-missing-detectors'] - data_args = {arg.replace('-', '_'): True for arg in bool_args - if cp.has_option('model', arg)} - ignore_args += bool_args - # load the data - opts = data_opts_from_config(cp, data_section, - args['low_frequency_cutoff']) - if data is None or psds is None: - strain_dict, psd_strain_dict = data_from_cli(opts, **data_args) - # convert to frequency domain and get psds - stilde_dict, psds = cls._fd_data_from_strain_dict( - opts, strain_dict, psd_strain_dict) - # save the psd data segments if the psd was estimated from data - if opts.psd_estimation: - _tdict = psd_strain_dict or strain_dict - for det in psds: - psds[det].psd_segment = (_tdict[det].start_time, - _tdict[det].end_time) - # gate overwhitened if desired - if opts.gate_overwhitened and opts.gate is not None: - stilde_dict = gate_overwhitened_data( - stilde_dict, psds, opts.gate) - data = stilde_dict - args.update({'data': data, 'psds': psds}) - # any extra args - args.update(cls.extra_args_from_config(cp, "model", - skip_args=ignore_args)) - # get ifo-specific instances of calibration model - if cp.has_section('calibration'): - logging.info("Initializing calibration model") - recalib = { - ifo: Recalibrate.from_config(cp, ifo, section='calibration') - for ifo in opts.instruments} - args['recalibration'] = recalib - # get gates for templates - gates = gates_from_cli(opts) - if gates: - args['gates'] = gates - args.update(kwargs) - return cls(**args) - - -class GaussianNoise(BaseGaussianNoise): - r"""Model that assumes data is stationary Gaussian noise. 
- - With Gaussian noise the log likelihood functions for signal - :math:`\log p(d|\Theta, h)` and for noise :math:`\log p(d|n)` are given by: - - .. math:: - - \log p(d|\Theta, h) &= \log\alpha -\frac{1}{2} \sum_i - \left< d_i - h_i(\Theta) | d_i - h_i(\Theta) \right> \\ - \log p(d|n) &= \log\alpha -\frac{1}{2} \sum_i \left - - where the sum is over the number of detectors, :math:`d_i` is the data in - each detector, and :math:`h_i(\Theta)` is the model signal in each - detector. The (discrete) inner product is given by: - - .. math:: - - \left = 4\Re \Delta f - \sum_{k=k_{\mathrm{min}}}^{k_{\mathrm{max}}} - \frac{\tilde{a}_i^{*}[k] \tilde{b}_i[k]}{S^{(i)}_n[k]}, - - where :math:`\Delta f` is the frequency resolution (given by 1 / the - observation time :math:`T`), :math:`k` is an index over the discretely - sampled frequencies :math:`f = k \Delta_f`, and :math:`S^{(i)}_n[k]` is the - PSD in the given detector. The upper cutoff on the inner product - :math:`k_{\max}` is by default the Nyquist frequency - :math:`k_{\max} = N/2+1`, where :math:`N = \lfloor T/\Delta t \rfloor` - is the number of samples in the time domain, but this can be set manually - to a smaller value. - - The normalization factor :math:`\alpha` is: - - .. math:: - - \alpha = \prod_{i} \frac{1}{\left(\pi T\right)^{N/2} - \prod_{k=k_\mathrm{min}}^{k_{\mathrm{max}}} S^{(i)}_n[k]}, - - where the product is over the number of detectors. By default, the - normalization constant is not included in the log likelihood, but it can - be turned on using the ``normalize`` keyword argument. - - Note that the log likelihood ratio has fewer terms than the log likelihood, - since the normalization and :math:`\left` terms cancel: - - .. math:: - - \log \mathcal{L}(\Theta) = \sum_i \left[ - \left - - \frac{1}{2} \left \right] - - Upon initialization, the data is whitened using the given PSDs. If no PSDs - are given the data and waveforms returned by the waveform generator are - assumed to be whitened. - - For more details on initialization parameters and definition of terms, see - :py:class:`models.BaseDataModel`. - - Parameters - ---------- - variable_params : (tuple of) string(s) - A tuple of parameter names that will be varied. - data : dict - A dictionary of data, in which the keys are the detector names and the - values are the data (assumed to be unwhitened). The list of keys must - match the waveform generator's detectors keys, and the epoch of every - data set must be the same as the waveform generator's epoch. - low_frequency_cutoff : dict - A dictionary of starting frequencies, in which the keys are the - detector names and the values are the starting frequencies for the - respective detectors to be used for computing inner products. - psds : dict, optional - A dictionary of FrequencySeries keyed by the detector names. The - dictionary must have a psd for each detector specified in the data - dictionary. If provided, the inner products in each detector will be - weighted by 1/psd of that detector. - high_frequency_cutoff : dict, optional - A dictionary of ending frequencies, in which the keys are the - detector names and the values are the ending frequencies for the - respective detectors to be used for computing inner products. If not - provided, the minimum of the largest frequency stored in the data - and a given waveform will be used. - normalize : bool, optional - If True, the normalization factor :math:`alpha` will be included in the - log likelihood. Default is to not include it. 
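The discrete inner product defined above is simple to evaluate directly; the following numpy sketch (illustrative names, toy inputs) implements the same 4 Δf Re Σ a*[k] b[k] / S_n[k] convention:

    import numpy

    def inner(a, b, psd, delta_f, kmin, kmax):
        # 4 * delta_f * Re sum_k conj(a[k]) * b[k] / S_n[k] over the analyzed band
        integrand = numpy.conj(a[kmin:kmax]) * b[kmin:kmax] / psd[kmin:kmax]
        return 4 * delta_f * integrand.sum().real

With ``b = a`` this gives the optimal SNR squared, i.e. the <h, h> term that appears in the likelihood ratio.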
- static_params : dict, optional - A dictionary of parameter names -> values to keep fixed. - \**kwargs : - All other keyword arguments are passed to ``BaseDataModel``. - - Examples - -------- - Create a signal, and set up the model using that signal: - - >>> from pycbc import psd as pypsd - >>> from pycbc.inference.models import GaussianNoise - >>> from pycbc.waveform.generator import (FDomainDetFrameGenerator, - ... FDomainCBCGenerator) - >>> seglen = 4 - >>> sample_rate = 2048 - >>> N = seglen*sample_rate/2+1 - >>> fmin = 30. - >>> static_params = {'approximant': 'IMRPhenomD', 'f_lower': fmin, - ... 'mass1': 38.6, 'mass2': 29.3, - ... 'spin1z': 0., 'spin2z': 0., 'ra': 1.37, 'dec': -1.26, - ... 'polarization': 2.76, 'distance': 3*500.} - >>> variable_params = ['tc'] - >>> tsig = 3.1 - >>> generator = FDomainDetFrameGenerator( - ... FDomainCBCGenerator, 0., detectors=['H1', 'L1'], - ... variable_args=variable_params, - ... delta_f=1./seglen, **static_params) - >>> signal = generator.generate(tc=tsig) - >>> psd = pypsd.aLIGOZeroDetHighPower(N, 1./seglen, 20.) - >>> psds = {'H1': psd, 'L1': psd} - >>> low_frequency_cutoff = {'H1': fmin, 'L1': fmin} - >>> model = GaussianNoise(variable_params, signal, low_frequency_cutoff, - psds=psds, static_params=static_params) - - Set the current position to the coalescence time of the signal: - - >>> model.update(tc=tsig) - - Now compute the log likelihood ratio and prior-weighted likelihood ratio; - since we have not provided a prior, these should be equal to each other: - - >>> print('{:.2f}'.format(model.loglr)) - 282.43 - >>> print('{:.2f}'.format(model.logplr)) - 282.43 - - Print all of the default_stats: - - >>> print(',\n'.join(['{}: {:.2f}'.format(s, v) - ... for (s, v) in sorted(model.current_stats.items())])) - H1_cplx_loglr: 177.76+0.00j, - H1_optimal_snrsq: 355.52, - L1_cplx_loglr: 104.67+0.00j, - L1_optimal_snrsq: 209.35, - logjacobian: 0.00, - loglikelihood: 0.00, - loglr: 282.43, - logprior: 0.00 - - Compute the SNR; for this system and PSD, this should be approximately 24: - - >>> from pycbc.conversions import snr_from_loglr - >>> x = snr_from_loglr(model.loglr) - >>> print('{:.2f}'.format(x)) - 23.77 - - Since there is no noise, the SNR should be the same as the quadrature sum - of the optimal SNRs in each detector: - - >>> x = (model.det_optimal_snrsq('H1') + - ... model.det_optimal_snrsq('L1'))**0.5 - >>> print('{:.2f}'.format(x)) - 23.77 - - Toggle on the normalization constant: - - >>> model.normalize = True - >>> model.loglikelihood - 835397.8757405131 - - Using the same model, evaluate the log likelihood ratio at several points - in time and check that the max is at tsig: - - >>> import numpy - >>> times = numpy.linspace(tsig-1, tsig+1, num=101) - >>> loglrs = numpy.zeros(len(times)) - >>> for (ii, t) in enumerate(times): - ... model.update(tc=t) - ... loglrs[ii] = model.loglr - >>> print('tsig: {:.2f}, time of max loglr: {:.2f}'.format( - ... tsig, times[loglrs.argmax()])) - tsig: 3.10, time of max loglr: 3.10 - - Create a prior and use it (see distributions module for more details): - - >>> from pycbc import distributions - >>> uniform_prior = distributions.Uniform(tc=(tsig-0.2,tsig+0.2)) - >>> prior = distributions.JointDistribution(variable_params, uniform_prior) - >>> model = GaussianNoise(variable_params, - ... signal, low_frequency_cutoff, psds=psds, prior=prior, - ... 
static_params=static_params) - >>> model.update(tc=tsig) - >>> print('{:.2f}'.format(model.logplr)) - 283.35 - >>> print(',\n'.join(['{}: {:.2f}'.format(s, v) - ... for (s, v) in sorted(model.current_stats.items())])) - H1_cplx_loglr: 177.76+0.00j, - H1_optimal_snrsq: 355.52, - L1_cplx_loglr: 104.67+0.00j, - L1_optimal_snrsq: 209.35, - logjacobian: 0.00, - loglikelihood: 0.00, - loglr: 282.43, - logprior: 0.92 - - """ - name = 'gaussian_noise' - - def __init__(self, variable_params, data, low_frequency_cutoff, psds=None, - high_frequency_cutoff=None, normalize=False, - static_params=None, **kwargs): - # set up the boiler-plate attributes - super(GaussianNoise, self).__init__( - variable_params, data, low_frequency_cutoff, psds=psds, - high_frequency_cutoff=high_frequency_cutoff, normalize=normalize, - static_params=static_params, **kwargs) - # Determine if all data have the same sampling rate and segment length - if self.all_ifodata_same_rate_length: - # create a waveform generator for all ifos - self.waveform_generator = create_waveform_generator( - self.variable_params, self.data, - waveform_transforms=self.waveform_transforms, - recalibration=self.recalibration, - gates=self.gates, **self.static_params) - else: - # create a waveform generator for each ifo respestively - self.waveform_generator = {} - for det in self.data: - self.waveform_generator[det] = create_waveform_generator( - self.variable_params, {det: self.data[det]}, - waveform_transforms=self.waveform_transforms, - recalibration=self.recalibration, - gates=self.gates, **self.static_params) - - @property - def _extra_stats(self): - """Adds ``loglr``, plus ``cplx_loglr`` and ``optimal_snrsq`` in each - detector.""" - return ['loglr'] + \ - ['{}_cplx_loglr'.format(det) for det in self._data] + \ - ['{}_optimal_snrsq'.format(det) for det in self._data] - - def _nowaveform_loglr(self): - """Convenience function to set loglr values if no waveform generated. - """ - for det in self._data: - setattr(self._current_stats, 'loglikelihood', -numpy.inf) - setattr(self._current_stats, '{}_cplx_loglr'.format(det), - -numpy.inf) - # snr can't be < 0 by definition, so return 0 - setattr(self._current_stats, '{}_optimal_snrsq'.format(det), 0.) - return -numpy.inf - - @property - def multi_signal_support(self): - """ The list of classes that this model supports in a multi-signal - likelihood - """ - return [type(self)] - - def multi_loglikelihood(self, models): - """ Calculate a multi-model (signal) likelihood - """ - # Generate the waveforms for each submodel - wfs = [] - for m in models + [self]: - wfs.append(m.get_waveforms()) - - # combine into a single waveform - combine = {} - for det in self.data: - mlen = max([len(x[det]) for x in wfs]) - [x[det].resize(mlen) for x in wfs] - combine[det] = sum([x[det] for x in wfs]) - - self._current_wfs = combine - loglr = self._loglr() - self._current_wfs = None - return loglr + self.lognl - - def get_waveforms(self): - """The waveforms generated using the current parameters. - - If the waveforms haven't been generated yet, they will be generated. - - Returns - ------- - dict : - Dictionary of detector names -> FrequencySeries. 
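The combination step in ``multi_loglikelihood`` above just pads each sub-model's waveform to a common length and sums them; a toy numpy version of that step (illustrative arrays only):

    import numpy

    wfs = [{'H1': numpy.ones(4096, dtype=complex)},        # toy waveforms from two submodels
           {'H1': 0.5 * numpy.ones(8193, dtype=complex)}]
    combined = {}
    for det in wfs[0]:
        mlen = max(len(w[det]) for w in wfs)
        combined[det] = sum(numpy.pad(w[det], (0, mlen - len(w[det]))) for w in wfs)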
- """ - if self._current_wfs is None: - params = self.current_params - if self.all_ifodata_same_rate_length: - wfs = self.waveform_generator.generate(**params) - else: - wfs = {} - for det in self.data: - wfs.update(self.waveform_generator[det].generate(**params)) - self._current_wfs = wfs - return self._current_wfs - - def _loglr(self): - r"""Computes the log likelihood ratio, - - .. math:: - - \log \mathcal{L}(\Theta) = \sum_i - \left - - \frac{1}{2}\left, - - at the current parameter values :math:`\Theta`. - - Returns - ------- - float - The value of the log likelihood ratio. - """ - try: - wfs = self.get_waveforms() - except NoWaveformError: - return self._nowaveform_loglr() - except FailedWaveformError as e: - if self.ignore_failed_waveforms: - return self._nowaveform_loglr() - else: - raise e - - lr = 0. - for det, h in wfs.items(): - # the kmax of the waveforms may be different than internal kmax - kmax = min(len(h), self._kmax[det]) - if self._kmin[det] >= kmax: - # if the waveform terminates before the filtering low frequency - # cutoff, then the loglr is just 0 for this detector - cplx_hd = 0j - hh = 0. - else: - slc = slice(self._kmin[det], kmax) - # whiten the waveform - h[self._kmin[det]:kmax] *= self._weight[det][slc] - - # the inner products - cplx_hd = self._whitened_data[det][slc].inner(h[slc]) # - hh = h[slc].inner(h[slc]).real # < h, h> - cplx_loglr = cplx_hd - 0.5 * hh - # store - setattr(self._current_stats, '{}_optimal_snrsq'.format(det), hh) - setattr(self._current_stats, '{}_cplx_loglr'.format(det), - cplx_loglr) - lr += cplx_loglr.real - # also store the loglikelihood, to ensure it is populated in the - # current stats even if loglikelihood is never called - self._current_stats.loglikelihood = lr + self.lognl - return float(lr) - - def det_cplx_loglr(self, det): - """Returns the complex log likelihood ratio in the given detector. - - Parameters - ---------- - det : str - The name of the detector. - - Returns - ------- - complex float : - The complex log likelihood ratio. - """ - # try to get it from current stats - try: - return getattr(self._current_stats, '{}_cplx_loglr'.format(det)) - except AttributeError: - # hasn't been calculated yet; call loglr to do so - self._loglr() - # now try returning again - return getattr(self._current_stats, '{}_cplx_loglr'.format(det)) - - def det_optimal_snrsq(self, det): - """Returns the opitmal SNR squared in the given detector. - - Parameters - ---------- - det : str - The name of the detector. - - Returns - ------- - float : - The opimtal SNR squared. - """ - # try to get it from current stats - try: - return getattr(self._current_stats, '{}_optimal_snrsq'.format(det)) - except AttributeError: - # hasn't been calculated yet; call loglr to do so - self._loglr() - # now try returning again - return getattr(self._current_stats, '{}_optimal_snrsq'.format(det)) - - -# -# ============================================================================= -# -# Support functions -# -# ============================================================================= -# - - -def get_values_from_injection(cp, injection_file, update_cp=True): - """Replaces all FROM_INJECTION values in a config file with the - corresponding value from the injection. - - This looks for any options that start with ``FROM_INJECTION[:ARG]`` in - a config file. It then replaces that value with the corresponding value - from the injection file. An argument may be optionally provided, in which - case the argument will be retrieved from the injection file. 
Functions of - parameters in the injection file may be used; the syntax and functions - available is the same as the ``--parameters`` argument in executables - such as ``pycbc_inference_extract_samples``. If no ``ARG`` is provided, - then the option name will try to be retrieved from the injection. - - For example, - - .. code-block:: ini - - mass1 = FROM_INJECTION - - will cause ``mass1`` to be retrieved from the injection file, while: - - .. code-block:: ini - - mass1 = FROM_INJECTION:'primary_mass(mass1, mass2)' - - will cause the larger of mass1 and mass2 to be retrieved from the injection - file. Note that if spaces are in the argument, it must be encased in - single quotes. - - The injection file may contain only one injection. Otherwise, a ValueError - will be raised. - - Parameters - ---------- - cp : ConfigParser - The config file within which to replace values. - injection_file : str or None - The injection file to get values from. A ValueError will be raised - if there are any ``FROM_INJECTION`` values in the config file, and - injection file is None, or if there is more than one injection. - update_cp : bool, optional - Update the config parser with the replaced parameters. If False, - will just retrieve the parameter values to update, without updating - the config file. Default is True. - - Returns - ------- - list - The parameters that were replaced, as a tuple of section name, option, - value. - """ - lookfor = 'FROM_INJECTION' - # figure out what parameters need to be set - replace_params = [] - for sec in cp.sections(): - for opt in cp.options(sec): - val = cp.get(sec, opt) - splitvals = shlex.split(val) - replace_this = [] - for ii, subval in enumerate(splitvals): - if subval.startswith(lookfor): - # determine what we should retrieve from the injection - subval = subval.split(':', 1) - if len(subval) == 1: - subval = opt - else: - subval = subval[1] - replace_this.append((ii, subval)) - if replace_this: - replace_params.append((sec, opt, splitvals, replace_this)) - if replace_params: - # check that we have an injection file - if injection_file is None: - raise ValueError("One or values are set to {}, but no injection " - "file provided".format(lookfor)) - # load the injection file - inj = InjectionSet(injection_file).table.view(type=FieldArray) - # make sure there's only one injection provided - if inj.size > 1: - raise ValueError("One or more values are set to {}, but more than " - "one injection exists in the injection file." - .format(lookfor)) - # get the injection values to replace - for ii, (sec, opt, splitvals, replace_this) in enumerate(replace_params): - # replace the value in the shlex-splitted string with the value - # from the injection - for jj, arg in replace_this: - splitvals[jj] = str(inj[arg][0]) - # now rejoin the string... - # shlex will strip quotes around arguments; this can be problematic - # when rejoining if the the argument had a space in it. In python 3.8 - # there is a shlex.join function which properly rejoins things taking - # that into account. Since we need to continue to support earlier - # versions of python, the following kludge tries to account for that. 
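# Illustrative aside (the option value below is an example, not taken from a
# real configuration): the parsing above splits the value on whitespace with
# shlex, then on the first ':' to separate the FROM_INJECTION tag from the
# optional argument.
import shlex

val = "FROM_INJECTION:'primary_mass(mass1, mass2)'"
token = shlex.split(val)[0]     # "FROM_INJECTION:primary_mass(mass1, mass2)"
tag, arg = token.split(':', 1)  # ('FROM_INJECTION', 'primary_mass(mass1, mass2)')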
- # If/when we drop support for all earlier versions of python, then the - # following can just be replaced by: - # replace_val = shlex.join(splitvals) - for jj, arg in enumerate(splitvals): - if ' ' in arg: - arg = "'" + arg + "'" - splitvals[jj] = arg - replace_val = ' '.join(splitvals) - replace_params[ii] = (sec, opt, replace_val) - # replace in the config file - if update_cp: - for (sec, opt, replace_val) in replace_params: - cp.set(sec, opt, replace_val) - return replace_params - - -def create_waveform_generator( - variable_params, data, waveform_transforms=None, - recalibration=None, gates=None, - generator_class=generator.FDomainDetFrameGenerator, - **static_params): - r"""Creates a waveform generator for use with a model. - - Parameters - ---------- - variable_params : list of str - The names of the parameters varied. - data : dict - Dictionary mapping detector names to either a - :py:class:`` or - :py:class:``. - waveform_transforms : list, optional - The list of transforms applied to convert variable parameters into - parameters that will be understood by the waveform generator. - recalibration : dict, optional - Dictionary mapping detector names to - :py:class:`` instances for - recalibrating data. - gates : dict of tuples, optional - Dictionary of detectors -> tuples of specifying gate times. The - sort of thing returned by :py:func:`pycbc.gate.gates_from_cli`. - generator_class : detector-frame fdomain generator, optional - Class to use for generating waveforms. Default is - :py:class:`waveform.generator.FDomainDetFrameGenerator`. - \**static_params : - All other keyword arguments are passed as static parameters to the - waveform generator. - - Returns - ------- - pycbc.waveform.FDomainDetFrameGenerator - A waveform generator for frequency domain generation. 
- """ - # the waveform generator will get the variable_params + the output - # of the waveform transforms, so we'll add them to the list of - # parameters - if waveform_transforms is not None: - wfoutputs = set.union(*[t.outputs - for t in waveform_transforms]) - else: - wfoutputs = set() - variable_params = list(variable_params) + list(wfoutputs) - # figure out what generator to use based on the approximant - try: - approximant = static_params['approximant'] - except KeyError: - raise ValueError("no approximant provided in the static args") - - generator_function = generator_class.select_rframe_generator(approximant) - # get data parameters; we'll just use one of the data to get the - # values, then check that all the others are the same - delta_f = None - for d in data.values(): - if delta_f is None: - delta_f = d.delta_f - delta_t = d.delta_t - start_time = d.start_time - else: - if not all([d.delta_f == delta_f, d.delta_t == delta_t, - d.start_time == start_time]): - raise ValueError("data must all have the same delta_t, " - "delta_f, and start_time") - waveform_generator = generator_class( - generator_function, epoch=start_time, - variable_args=variable_params, detectors=list(data.keys()), - delta_f=delta_f, delta_t=delta_t, - recalib=recalibration, gates=gates, - **static_params) - return waveform_generator diff --git a/pycbc/inference/models/.ipynb_checkpoints/hierarchical-checkpoint.py b/pycbc/inference/models/.ipynb_checkpoints/hierarchical-checkpoint.py deleted file mode 100644 index 560e216a2ca..00000000000 --- a/pycbc/inference/models/.ipynb_checkpoints/hierarchical-checkpoint.py +++ /dev/null @@ -1,566 +0,0 @@ -# Copyright (C) 2022 Collin Capano -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# - -"""Hierarchical model definitions.""" - -import shlex -import logging -from pycbc import transforms -from pycbc.workflow import WorkflowConfigParser -from .base import BaseModel - -# -# ============================================================================= -# -# Hierarhical model definition -# -# ============================================================================= -# - - -class HierarchicalModel(BaseModel): - r"""Model that is a combination of other models. - - Sub-models are treated as being independent of each other, although - they can share parameters. In other words, the hiearchical likelihood is: - - .. math:: - - p(\mathbf{D}|\mathbf{\vartheta}, \mathbf{H}) = - \prod_{I}^{K} p(\mathbf{d}_I|\mathbf{\vartheta}, H_{I}) - - Submodels are provided as a dictionary upon initialization with a unique - label assigned to each model, e.g., ``{'event1' -> model1, 'event2' -> - model2}``. 
Variable and static parameters that are specific to each - submodel should be prepended with ``{label}__``, where ``{label}__`` is the - label associated with the given submodel. Shared parameters across multiple - models have no labels prepended. To specify shared models over a subset of - models, separate models with an underscore. For example, - ``event1_event2__foo`` will result in ``foo`` being common between models - ``event1`` and ``event2``. For more details on parameter naming see - :py:class:`HierarchicalParam - `. - - All waveform and sampling transforms, as well as prior evaluation, are - handled by this model, not the sub-models. Parameters created by waveform - transforms should therefore also have sub-model names prepended to them, - to indicate which models they should be provided to for likelihood - evaluation. - - Parameters - ---------- - variable_params: (tuple of) string(s) - A tuple of parameter names that will be varied. - submodels: dict - Dictionary of model labels -> model instances of all the submodels. - \**kwargs : - All other keyword arguments are passed to - :py:class:`BaseModel `. - """ - name = 'hierarchical' - - def __init__(self, variable_params, submodels, **kwargs): - # sub models is assumed to be a dict of model labels -> model instances - self.submodels = submodels - # initialize standard attributes - super().__init__(variable_params, **kwargs) - # store a map of model labels -> parameters for quick look up later - self.param_map = map_params(self.hvariable_params) - # add any parameters created by waveform transforms - if self.waveform_transforms is not None: - derived_params = set() - derived_params.update(*[t.outputs - for t in self.waveform_transforms]) - # convert to hierarchical params - derived_params = map_params(hpiter(derived_params, - list(self.submodels.keys()))) - for lbl, pset in derived_params.items(): - self.param_map[lbl].update(pset) - # make sure the static parameters of all submodels are set correctly - self.static_param_map = map_params(self.hstatic_params.keys()) - # also create a map of model label -> extra stats created by each model - # stats are prepended with the model label. We'll include the - # loglikelihood returned by each submodel in the extra stats. - self.extra_stats_map = {} - self.__extra_stats = [] - for lbl, model in self.submodels.items(): - model.static_params = {p.subname: self.static_params[p.fullname] - for p in self.static_param_map[lbl]} - self.extra_stats_map.update(map_params([ - HierarchicalParam.from_subname(lbl, p) - for p in model._extra_stats+['loglikelihood']])) - self.__extra_stats += self.extra_stats_map[lbl] - # also make sure the model's sampling transforms and waveform - # transforms are not set, as these are handled by the hierarchical - # model - if model.sampling_transforms is not None: - raise ValueError("Model {} has sampling transforms set; " - "in a hierarchical analysis, these are " - "handled by the hiearchical model" - .format(lbl)) - if model.waveform_transforms is not None: - raise ValueError("Model {} has waveform transforms set; " - "in a hierarchical analysis, these are " - "handled by the hiearchical model" - .format(lbl)) - - @property - def hvariable_params(self): - """The variable params as a tuple of :py:class:`HierarchicalParam` - instances. 
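A short illustration of these naming rules, using the classes defined later in this module (import path as in the non-checkpoint pycbc.inference.models.hierarchical module; the labels are toy values):

    from pycbc.inference.models.hierarchical import HierarchicalParam, map_params

    labels = {'event1', 'event2'}
    p = HierarchicalParam('event1_event2__foo', labels)
    print(p.models, p.subname)      # {'event1', 'event2'} foo
    shared = HierarchicalParam('distance', labels)
    print(shared.models)            # no label, so common to all: {'event1', 'event2'}
    print(map_params([p, shared]))  # {'event1': {...}, 'event2': {...}}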
- """ - return self._variable_params - - @property - def variable_params(self): - # converts variable params back to a set of strings before returning - return tuple(p.fullname for p in self._variable_params) - - @variable_params.setter - def variable_params(self, variable_params): - # overrides BaseModel's variable params to store the variable params - # as HierarchicalParam instances - if isinstance(variable_params, str): - variable_params = [variable_params] - self._variable_params = tuple(HierarchicalParam(p, self.submodels) - for p in variable_params) - - @property - def hstatic_params(self): - """The static params with :py:class:`HierarchicalParam` instances used - as dictionary keys. - """ - return self._static_params - - @property - def static_params(self): - # converts the static param keys back to strings - return {p.fullname: val for p, val in self._static_params.items()} - - @static_params.setter - def static_params(self, static_params): - if static_params is None: - static_params = {} - self._static_params = {HierarchicalParam(p, self.submodels): val - for p, val in static_params.items()} - - @property - def _extra_stats(self): - return [p.fullname for p in self.__extra_stats] - - @property - def _hextra_stats(self): - """The extra stats as :py:class:`HierarchicalParam` instances.""" - return self.__extra_stats - - def _loglikelihood(self): - # takes the sum of the constitutent models' loglikelihoods - logl = 0. - for lbl, model in self.submodels.items(): - # update the model with the current params. This is done here - # instead of in `update` because waveform transforms are not - # applied until the loglikelihood function is called - model.update(**{p.subname: self.current_params[p.fullname] - for p in self.param_map[lbl]}) - # now get the loglikelihood from the model - sublogl = model.loglikelihood - # store the extra stats - mstats = model.current_stats - for stat in self.extra_stats_map[lbl]: - setattr(self._current_stats, stat, mstats[stat.subname]) - # add to the total loglikelihood - logl += sublogl - return logl - - def write_metadata(self, fp, group=None): - """Adds data to the metadata that's written. - - Parameters - ---------- - fp : pycbc.inference.io.BaseInferenceFile instance - The inference file to write to. - group : str, optional - If provided, the metadata will be written to the attrs specified - by group, i.e., to ``fp[group].attrs``. Otherwise, metadata is - written to the top-level attrs (``fp.attrs``). - - """ - # write information about self - super().write_metadata(fp, group=group) - # write information about each submodel into a different group for - # each one - if group is None or group == '/': - prefix = '' - else: - prefix = group+'/' - for lbl, model in self.submodels.items(): - model.write_metadata(fp, group=prefix+lbl) - - @classmethod - def from_config(cls, cp, **kwargs): - r"""Initializes an instance of this class from the given config file. - - Sub-models are initialized before initializing this class. The model - section must have a ``submodels`` argument that lists the names of all - the submodels to generate as a space-separated list. Each sub-model - should have its own ``[{label}__model]`` section that sets up the - model for that sub-model. For example: - - .. code-block:: ini - - [model] - name = hiearchical - submodels = event1 event2 - - [event1__model] - - - [event2__model] - - - Similarly, all other sections that are specific to a model should start - with the model's label. 
All sections starting with a model's label will - be passed to that model's ``from_config`` method with the label removed - from the section name. For example, if a sub-model requires a data - section to be specified, it should be titled ``[{label}__data]``. Upon - initialization, the ``{label}__`` will be stripped from the section - header and passed to the model. - - No model labels should preceed the ``variable_params``, - ``static_params``, ``waveform_transforms``, or ``sampling_transforms`` - sections. Instead, the parameters specified in these sections should - follow the naming conventions described in :py:class:`HierachicalParam` - to determine which sub-model(s) they belong to. (Sampling parameters - can follow any naming convention, as they are only handled by the - hierarchical model.) This is because the hierarchical model handles - all transforms, communication with the sampler, file IO, and prior - calculation. Only sub-model's loglikelihood functions are called. - - Metadata for each sub-model is written to the output hdf file under - groups given by the sub-model label. For example, if we have two - submodels labelled ``event1`` and ``event2``, there will be groups - with the same names in the top level of the output that contain that - model's subdata. For instance, if event1 used the ``gaussian_noise`` - model, the GW data and PSDs will be found in ``event1/data`` and the - low frequency cutoff used for that model will be in the ``attrs`` of - the ``event1`` group. - - Parameters - ---------- - cp : WorkflowConfigParser - Config file parser to read. - \**kwargs : - All additional keyword arguments are passed to the class. Any - provided keyword will override what is in the config file. - """ - # we need the read from config function from the init; to prevent - # circular imports, we import it here - from pycbc.inference.models import read_from_config - # get the submodels - submodel_lbls = shlex.split(cp.get('model', 'submodels')) - # sort parameters by model - vparam_map = map_params(hpiter(cp.options('variable_params'), - submodel_lbls)) - sparam_map = map_params(hpiter(cp.options('static_params'), - submodel_lbls)) - - # we'll need any waveform transforms for the initializing sub-models, - # as the underlying models will receive the output of those transforms - if any(cp.get_subsections('waveform_transforms')): - waveform_transforms = transforms.read_transforms_from_config( - cp, 'waveform_transforms') - wfoutputs = set.union(*[t.outputs - for t in waveform_transforms]) - wfparam_map = map_params(hpiter(wfoutputs, submodel_lbls)) - else: - wfparam_map = {lbl: [] for lbl in submodel_lbls} - # initialize the models - submodels = {} - logging.info("Loading submodels") - for lbl in submodel_lbls: - logging.info("============= %s =============", lbl) - # create a config parser to pass to the model - subcp = WorkflowConfigParser() - # copy sections over that start with the model label (this should - # include the [model] section for that model) - copy_sections = [ - HierarchicalParam(sec, submodel_lbls) - for sec in cp.sections() if lbl in - sec.split('-')[0].split(HierarchicalParam.delim, 1)[0]] - for sec in copy_sections: - # check that the user isn't trying to set variable or static - # params for the model (we won't worry about waveform or - # sampling transforms here, since that is checked for in the - # __init__) - if sec.subname in ['variable_params', 'static_params']: - raise ValueError("Section {} found in the config file; " - "[variable_params] and [static_params] 
" - "sections should not include model " - "labels. To specify parameters unique to " - "one or more sub-models, prepend the " - "individual parameter names with the " - "model label. See HierarchicalParam for " - "details.".format(sec)) - subcp.add_section(sec.subname) - for opt, val in cp.items(sec): - subcp.set(sec.subname, opt, val) - # set the static params - subcp.add_section('static_params') - for param in sparam_map[lbl]: - subcp.set('static_params', param.subname, - cp.get('static_params', param.fullname)) - # set the variable params: for now we'll just set all the - # variable params as static params - # so that the model doesn't raise an error looking for - # prior sections. We'll then manually set the variable - # params after the model is initialized - - subcp.add_section('variable_params') - for param in vparam_map[lbl]: - subcp.set('static_params', param.subname, 'REPLACE') - # add the outputs from the waveform transforms - for param in wfparam_map[lbl]: - subcp.set('static_params', param.subname, 'REPLACE') - - # initialize - submodel = read_from_config(subcp) - # move the static params back to variable - for p in vparam_map[lbl]: - submodel.static_params.pop(p.subname) - submodel.variable_params = tuple(p.subname - for p in vparam_map[lbl]) - # remove the waveform transform parameters - for p in wfparam_map[lbl]: - submodel.static_params.pop(p.subname) - # store - submodels[lbl] = submodel - logging.info("") - # now load the model - logging.info("Loading hierarchical model") - return super().from_config(cp, submodels=submodels) - - -class HierarchicalParam(str): - """Sub-class of str for hierarchical parameter names. - - This adds attributes that keep track of the model label(s) the parameter - is associated with, along with the name that is passed to the models. - - The following conventions are used for parsing parameter names: - - * Model labels and parameter names are separated by the ``delim`` class - attribute, which by default is ``__``, e.g., ``event1__mass``. - * Multiple model labels can be provided by separating the model labels - with the ``model_delim`` class attribute, which by default is ``_``, - e.g., ``event1_event2__mass``. Note that this means that individual - model labels cannot contain ``_``, else they'll be parsed as separate - models. - * Parameters that have no model labels prepended to them (i.e., there - is no ``__`` in the name) are common to all models. - - These parsing rules are applied by the :py:meth:`HierarchicalParam.parse` - method. - - Parameters - ---------- - fullname : str - Name of the hierarchical parameter. Should have format - ``{model1}[_{model2}[_{...}]]__{param}``. - possible_models : set of str - The possible sub-models a parameter can belong to. Should a set of - model labels. - - Attributes - ---------- - fullname : str - The full name of the parameter, including model labels. For example, - ``e1_e2__foo``. - models : set - The model labels the parameter is associated with. For example, - ``e1_e2__foo`` yields models ``e1, e2``. - subname : str - The name of the parameter without the model labels prepended to it. - For example, ``e1_e2__foo`` yields ``foo``. 
- """ - delim = '__' - model_delim = '_' - - def __new__(cls, fullname, possible_models): - fullname = str(fullname) - obj = str.__new__(cls, fullname) - obj.fullname = fullname - models, subp = HierarchicalParam.parse(fullname, possible_models) - obj.models = models - obj.subname = subp - return obj - - @classmethod - def from_subname(cls, model_label, subname): - """Creates a HierarchicalParam from the given subname and model label. - """ - return cls(cls.delim.join([model_label, subname]), set([model_label])) - - @classmethod - def parse(cls, fullname, possible_models): - """Parses the full parameter name into the models the parameter is - associated with and the parameter name that is passed to the models. - - Parameters - ---------- - fullname : str - The full name of the parameter, which includes both the model - label(s) and the parameter name. - possible_models : set - Set of model labels the parameter can be associated with. - - Returns - ------- - models : list - List of the model labels the parameter is associated with. - subp : str - Parameter name that is passed to the models. This is the parameter - name with the model label(s) stripped from it. - """ - # make sure possible models is a set - possible_models = set(possible_models) - p = fullname.split(cls.delim, 1) - if len(p) == 1: - # is a global fullname, associate with all - subp = fullname - models = possible_models.copy() - else: - models, subp = p - # convert into set of model label(s) - models = set(models.split(cls.model_delim)) - # make sure the given labels are in the list of possible models - unknown = models - possible_models - if any(unknown): - raise ValueError('unrecognized model label(s) {} present in ' - 'parameter {}'.format(', '.join(unknown), - fullname)) - return models, subp - - -def hpiter(params, possible_models): - """Turns a list of parameter strings into a list of HierarchicalParams. - - Parameters - ---------- - params : list of str - List of parameter names. - possible_models : set - Set of model labels the parameters can be associated with. - - Returns - ------- - iterator : - Iterator of :py:class:`HierarchicalParam` instances. - """ - return map(lambda x: HierarchicalParam(x, possible_models), params) - - -def map_params(params): - """Creates a map of models -> parameters. - - Parameters - ---------- - params : list of HierarchicalParam instances - The list of hierarchical parameter names to parse. - - Returns - ------- - dict : - Dictionary of model labels -> associated parameters. - """ - param_map = {} - for p in params: - for lbl in p.models: - try: - param_map[lbl].update([p]) - except KeyError: - param_map[lbl] = set([p]) - return param_map - - -class MultiSignalModel(HierarchicalModel): - """ Model for multiple signals which share data - - Sub models are treated as if the signals overlap in data. This requires - constituent models to implement a specific method to handle this case. - All models must be of the same type or the specific model is responsible - for implement cross-compatibility with another model. Each model h_i is - responsible for calculating its own loglikelihood ratio for itself, and - must also implement a method to calculate crossterms of the form - which arise from the full calculation of . - This model inherits from the HierarchicalModel so the syntax for - configuration files is the same. The primary model is used to determine - the noise terms , which by default will be the first model used. 
- """ - name = 'multi_signal' - - def __init__(self, variable_params, submodels, **kwargs): - super().__init__(variable_params, submodels, **kwargs) - - # Check what models each model supports - support = {} - ctypes = set() # The set of models we need to completely support - for lbl in self.submodels: - model = self.submodels[lbl] - - ctypes.add(type(model)) - if hasattr(model, 'multi_signal_support'): - support[lbl] = set(model.multi_signal_support) - - # pick the primary model if it supports the set of constituent models - for lbl in support: - if ctypes <= support[lbl]: - self.primary_model = lbl - logging.info('MultiSignalModel: PrimaryModel == %s', lbl) - break - else: - # Oh, no, we don't support this combo! - raise RuntimeError("It looks like the combination of models, {}," - "for the MultiSignal model isn't supported by" - "any of the constituent models.".format(ctypes)) - - self.other_models = self.submodels.copy() - self.other_models.pop(self.primary_model) - self.other_models = list(self.other_models.values()) - - def _loglikelihood(self): - for lbl, model in self.submodels.items(): - # Update the parameters of each - model.update(**{p.subname: self.current_params[p.fullname] - for p in self.param_map[lbl]}) - - # Calculate the combined loglikelihood - p = self.primary_model - logl = self.submodels[p].multi_loglikelihood(self.other_models) - - # store any extra stats from the submodels - for lbl, model in self.submodels.items(): - mstats = model.current_stats - for stat in self.extra_stats_map[lbl]: - setattr(self._current_stats, stat, mstats[stat.subname]) - return logl diff --git a/pycbc/inference/sampler/.ipynb_checkpoints/dynesty-checkpoint.py b/pycbc/inference/sampler/.ipynb_checkpoints/dynesty-checkpoint.py deleted file mode 100644 index 6b1286a4afd..00000000000 --- a/pycbc/inference/sampler/.ipynb_checkpoints/dynesty-checkpoint.py +++ /dev/null @@ -1,649 +0,0 @@ -# Copyright (C) 2019 Collin Capano, Sumit Kumar, Prayush Kumar -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# -""" -This modules provides classes and functions for using the dynesty sampler -packages for parameter estimation. -""" - -import logging -import time -import numpy -import dynesty, dynesty.dynesty, dynesty.nestedsamplers -from pycbc.pool import choose_pool -from dynesty import utils as dyfunc -from pycbc.inference.io import (DynestyFile, validate_checkpoint_files, - loadfile) -from .base import (BaseSampler, setup_output) -from .base_mcmc import get_optional_arg_from_config -from .base_cube import setup_calls -from .. 
import models - - -# -# ============================================================================= -# -# Samplers -# -# ============================================================================= -# - -class DynestySampler(BaseSampler): - """This class is used to construct an Dynesty sampler from the dynesty - package. - - Parameters - ---------- - model : model - A model from ``pycbc.inference.models``. - nlive : int - Number of live points to use in sampler. - pool : function with map, Optional - A provider of a map function that allows a function call to be run - over multiple sets of arguments and possibly maps them to - cores/nodes/etc. - """ - name = "dynesty" - _io = DynestyFile - - def __init__(self, model, nlive, nprocesses=1, - checkpoint_time_interval=None, maxcall=None, - loglikelihood_function=None, use_mpi=False, - no_save_state=False, - run_kwds=None, - extra_kwds=None, - internal_kwds=None, - **kwargs): - - self.model = model - self.no_save_state = no_save_state - log_likelihood_call, prior_call = setup_calls( - model, - loglikelihood_function=loglikelihood_function, - copy_prior=True) - # Set up the pool - self.pool = choose_pool(mpi=use_mpi, processes=nprocesses) - - self.maxcall = maxcall - self.checkpoint_time_interval = checkpoint_time_interval - self.run_kwds = {} if run_kwds is None else run_kwds - self.extra_kwds = {} if extra_kwds is None else extra_kwds - self.internal_kwds = {} if internal_kwds is None else internal_kwds - self.nlive = nlive - self.names = model.sampling_params - self.ndim = len(model.sampling_params) - self.checkpoint_file = None - # Enable checkpointing if checkpoint_time_interval is set in config - # file in sampler section - if self.checkpoint_time_interval: - self.run_with_checkpoint = True - if self.maxcall is None: - self.maxcall = 5000 * self.pool.size - logging.info("Checkpointing enabled, will verify every %s calls" - " and try to checkpoint every %s seconds", - self.maxcall, self.checkpoint_time_interval) - else: - self.run_with_checkpoint = False - - # Check for cyclic boundaries - periodic = [] - cyclic = self.model.prior_distribution.cyclic - for i, param in enumerate(self.variable_params): - if param in cyclic: - logging.info('Param: %s will be cyclic', param) - periodic.append(i) - - if len(periodic) == 0: - periodic = None - - # Check for reflected boundaries. Dynesty only supports - # reflection on both min and max of boundary. 
- reflective = [] - reflect = self.model.prior_distribution.well_reflected - for i, param in enumerate(self.variable_params): - if param in reflect: - logging.info("Param: %s will be well reflected", param) - reflective.append(i) - - if len(reflective) == 0: - reflective = None - - if 'sample' in extra_kwds: - if 'rwalk2' in extra_kwds['sample']: - dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_mod - dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_mod - extra_kwds['sample'] = 'rwalk' - - if self.nlive < 0: - # Interpret a negative input value for the number of live points - # (which is clearly an invalid input in all senses) - # as the desire to dynamically determine that number - self._sampler = dynesty.DynamicNestedSampler(log_likelihood_call, - prior_call, self.ndim, - pool=self.pool, - reflective=reflective, - periodic=periodic, - **extra_kwds) - self.run_with_checkpoint = False - logging.info("Checkpointing not currently supported with" - "DYNAMIC nested sampler") - else: - self._sampler = dynesty.NestedSampler(log_likelihood_call, - prior_call, self.ndim, - nlive=self.nlive, - reflective=reflective, - periodic=periodic, - pool=self.pool, **extra_kwds) - self._sampler.kwargs.update(internal_kwds) - - # properties of the internal sampler which should not be pickled - self.no_pickle = ['loglikelihood', - 'prior_transform', - 'propose_point', - 'update_proposal', - '_UPDATE', '_PROPOSE', - 'evolve_point', 'use_pool', 'queue_size', - 'use_pool_ptform', 'use_pool_logl', - 'use_pool_evolve', 'use_pool_update', - 'pool', 'M'] - - def run(self): - diff_niter = 1 - if self.run_with_checkpoint is True: - n_checkpointing = 1 - t0 = time.time() - it = self._sampler.it - - logging.info('Starting from iteration: %s', it) - while diff_niter != 0: - self._sampler.run_nested(maxcall=self.maxcall, **self.run_kwds) - - delta_t = time.time() - t0 - diff_niter = self._sampler.it - it - logging.info("Checking if we should checkpoint: %.2f s", delta_t) - - if delta_t >= self.checkpoint_time_interval: - logging.info('Checkpointing N={}'.format(n_checkpointing)) - self.checkpoint() - n_checkpointing += 1 - t0 = time.time() - it = self._sampler.it - else: - self._sampler.run_nested(**self.run_kwds) - - @property - def io(self): - return self._io - - @property - def niterations(self): - return len(tuple(self.samples.values())[0]) - - @classmethod - def from_config(cls, cp, model, output_file=None, nprocesses=1, - use_mpi=False, loglikelihood_function=None): - """Loads the sampler from the given config file. Many options are - directly passed to the underlying dynesty sampler, see the official - dynesty documentation for more details on these. - - The following options are retrieved in the ``[sampler]`` section: - - * ``name = STR``: - Required. This must match the sampler's name. - * ``maxiter = INT``: - The maximum number of iterations to run. - * ``dlogz = FLOAT``: - The target dlogz stopping condition. - * ``logl_max = FLOAT``: - The maximum logl stopping condition. - * ``n_effective = INT``: - Target effective number of samples stopping condition - * ``sample = STR``: - The method to sample the space. Should be one of 'uniform', - 'rwalk', 'rwalk2' (a modified version of rwalk), or 'slice'. - * ``walk = INT``: - Used for some of the walk methods. Sets the minimum number of - steps to take when evolving a point. - * ``maxmcmc = INT``: - Used for some of the walk methods. Sets the maximum number of steps - to take when evolving a point. - * ``nact = INT``: - used for some of the walk methods. 
Sets number of autorcorrelation - lengths before terminating evolution of a point. - * ``first_update_min_ncall = INT``: - The minimum number of calls before updating the bounding region - for the first time. - * ``first_update_min_neff = FLOAT``: - Don't update the the bounding region untill the efficiency drops - below this value. - * ``bound = STR``: - The method of bounding of the prior volume. - Should be one of 'single', 'balls', 'cubes', 'multi' or 'none'. - * ``update_interval = INT``: - Number of iterations between updating the bounding regions - * ``enlarge = FLOAT``: - Factor to enlarge the bonding region. - * ``bootstrap = INT``: - The number of bootstrap iterations to determine the enlargement - factor. - * ``maxcall = INT``: - The maximum number of calls before checking if we should checkpoint - * ``checkpoint_time_interval``: - Sets the time in seconds between checkpointing. - * ``loglikelihood-function``: - The attribute of the model to use for the loglikelihood. If - not provided, will default to ``loglikelihood``. - - Parameters - ---------- - cp : WorkflowConfigParser instance - Config file object to parse. - model : pycbc.inference.model.BaseModel instance - The model to use. - output_file : str, optional - The name of the output file to checkpoint and write results to. - nprocesses : int, optional - The number of parallel processes to use. Default is 1. - use_mpi : bool, optional - Use MPI for parallelization. Default is False. - - Returns - ------- - DynestySampler : - The sampler instance. - """ - section = "sampler" - # check name - assert cp.get(section, "name") == cls.name, ( - "name in section [sampler] must match mine") - # get the number of live points to use - nlive = int(cp.get(section, "nlive")) - loglikelihood_function = \ - get_optional_arg_from_config(cp, section, 'loglikelihood-function') - - no_save_state = cp.has_option(section, 'no-save-state') - - # optional run_nested arguments for dynesty - rargs = {'maxiter': int, - 'dlogz': float, - 'logl_max': float, - 'n_effective': int, - } - - # optional arguments for dynesty - cargs = {'bound': str, - 'bootstrap': int, - 'enlarge': float, - 'update_interval': float, - 'sample': str, - 'first_update_min_ncall': int, - 'first_update_min_eff': float, - 'walks': int, - } - - # optional arguments that must be set internally - internal_args = { - 'maxmcmc': int, - 'nact': int, - } - - extra = {} - run_extra = {} - internal_extra = {} - for args, argt in [(extra, cargs), - (run_extra, rargs), - (internal_extra, internal_args), - ]: - for karg in argt: - if cp.has_option(section, karg): - args[karg] = argt[karg](cp.get(section, karg)) - - #This arg needs to be a dict - first_update = {} - if 'first_update_min_ncall' in extra: - first_update['min_ncall'] = extra.pop('first_update_min_ncall') - logging.info('First update: min_ncall:%s', - first_update['min_ncall']) - if 'first_update_min_eff' in extra: - first_update['min_eff'] = extra.pop('first_update_min_eff') - logging.info('First update: min_eff:%s', first_update['min_eff']) - extra['first_update'] = first_update - - # populate options for checkpointing - checkpoint_time_interval = None - maxcall = None - if cp.has_option(section, 'checkpoint_time_interval'): - ck_time = float(cp.get(section, 'checkpoint_time_interval')) - checkpoint_time_interval = ck_time - if cp.has_option(section, 'maxcall'): - maxcall = int(cp.get(section, 'maxcall')) - - obj = cls(model, nlive=nlive, nprocesses=nprocesses, - loglikelihood_function=loglikelihood_function, - 
checkpoint_time_interval=checkpoint_time_interval, - maxcall=maxcall, - no_save_state=no_save_state, - use_mpi=use_mpi, run_kwds=run_extra, - extra_kwds=extra, - internal_kwds=internal_extra,) - setup_output(obj, output_file, check_nsamples=False) - - if not obj.new_checkpoint: - obj.resume_from_checkpoint() - return obj - - def checkpoint(self): - """Checkpoint function for dynesty sampler - """ - # Dynesty has its own __getstate__ which deletes - # random state information and the pool - saved = {} - for key in self.no_pickle: - if hasattr(self._sampler, key): - saved[key] = getattr(self._sampler, key) - setattr(self._sampler, key, None) - for fn in [self.checkpoint_file, self.backup_file]: - with self.io(fn, "a") as fp: - # Write random state - fp.write_random_state() - - # Write pickled data - fp.write_pickled_data_into_checkpoint_file(self._sampler) - - self.write_results(fn) - - # Restore properties that couldn't be pickled if we are continuing - for key in saved: - setattr(self._sampler, key, saved[key]) - - def resume_from_checkpoint(self): - try: - with loadfile(self.checkpoint_file, 'r') as fp: - sampler = fp.read_pickled_data_from_checkpoint_file() - - for key in sampler.__dict__: - if key not in self.no_pickle: - value = getattr(sampler, key) - setattr(self._sampler, key, value) - - self.set_state_from_file(self.checkpoint_file) - logging.info("Found valid checkpoint file: %s", - self.checkpoint_file) - except Exception as e: - print(e) - logging.info("Failed to load checkpoint file") - - def set_state_from_file(self, filename): - """Sets the state of the sampler back to the instance saved in a file. - """ - with self.io(filename, 'r') as fp: - state = fp.read_random_state() - # Dynesty handles most randomeness through rstate which is - # pickled along with the class instance - numpy.random.set_state(state) - - def finalize(self): - """Finalze and write it to the results file - """ - logz = self._sampler.results.logz[-1:][0] - dlogz = self._sampler.results.logzerr[-1:][0] - logging.info("log Z, dlog Z: {}, {}".format(logz, dlogz)) - - if self.no_save_state: - self.write_results(self.checkpoint_file) - else: - self.checkpoint() - logging.info("Validating checkpoint and backup files") - checkpoint_valid = validate_checkpoint_files( - self.checkpoint_file, self.backup_file, check_nsamples=False) - if not checkpoint_valid: - raise IOError("error writing to checkpoint file") - - @property - def samples(self): - """Returns raw nested samples - """ - results = self._sampler.results - samples = results.samples - nest_samp = {} - for i, param in enumerate(self.variable_params): - nest_samp[param] = samples[:, i] - nest_samp['logwt'] = results.logwt - nest_samp['loglikelihood'] = results.logl - return nest_samp - - def set_initial_conditions(self, initial_distribution=None, - samples_file=None): - """Sets up the starting point for the sampler. - - Should also set the sampler's random state. - """ - pass - - def write_results(self, filename): - """Writes samples, model stats, acceptance fraction, and random state - to the given file. - - Parameters - ----------- - filename : str - The file to write to. The file is opened using the ``io`` class - in an an append state. 
- """ - with self.io(filename, 'a') as fp: - # Write nested samples - fp.write_raw_samples(self.samples) - - # Write logz and dlogz - logz = self._sampler.results.logz[-1:][0] - dlogz = self._sampler.results.logzerr[-1:][0] - fp.write_logevidence(logz, dlogz) - - @property - def model_stats(self): - pass - - @property - def logz(self): - """ - return bayesian evidence estimated by - dynesty sampler - """ - return self._sampler.results.logz[-1:][0] - - @property - def logz_err(self): - """ - return error in bayesian evidence estimated by - dynesty sampler - """ - return self._sampler.results.logzerr[-1:][0] - - -def sample_rwalk_mod(args): - """ Modified version of dynesty.sampling.sample_rwalk - - Adapted from version used in bilby/dynesty - """ - try: - # dynesty <= 1.1 - from dynesty.utils import unitcheck, reflect - - # Unzipping. - (u, loglstar, axes, scale, - prior_transform, loglikelihood, kwargs) = args - - except ImportError: - # dynest >= 1.2 - from dynesty.utils import unitcheck, apply_reflect as reflect - - (u, loglstar, axes, scale, - prior_transform, loglikelihood, _, kwargs) = args - - rstate = numpy.random - - # Bounds - nonbounded = kwargs.get('nonbounded', None) - periodic = kwargs.get('periodic', None) - reflective = kwargs.get('reflective', None) - - # Setup. - n = len(u) - walks = kwargs.get('walks', 10 * n) # minimum number of steps - maxmcmc = kwargs.get('maxmcmc', 2000) # Maximum number of steps - nact = kwargs.get('nact', 5) # Number of ACT - old_act = kwargs.get('old_act', walks) - - # Initialize internal variables - accept = 0 - reject = 0 - nfail = 0 - act = numpy.inf - u_list = [] - v_list = [] - logl_list = [] - - ii = 0 - while ii < nact * act: - ii += 1 - - # Propose a direction on the unit n-sphere. - drhat = rstate.randn(n) - drhat /= numpy.linalg.norm(drhat) - - # Scale based on dimensionality. - dr = drhat * rstate.rand() ** (1.0 / n) - - # Transform to proposal distribution. - du = numpy.dot(axes, dr) - u_prop = u + scale * du - - # Wrap periodic parameters - if periodic is not None: - u_prop[periodic] = numpy.mod(u_prop[periodic], 1) - # Reflect - if reflective is not None: - u_prop[reflective] = reflect(u_prop[reflective]) - - # Check unit cube constraints. - if u.max() < 0: - break - if unitcheck(u_prop, nonbounded): - pass - else: - nfail += 1 - # Only start appending to the chain once a single jump is made - if accept > 0: - u_list.append(u_list[-1]) - v_list.append(v_list[-1]) - logl_list.append(logl_list[-1]) - continue - - # Check proposed point. 
- v_prop = prior_transform(numpy.array(u_prop)) - logl_prop = loglikelihood(numpy.array(v_prop)) - if logl_prop > loglstar: - u = u_prop - v = v_prop - logl = logl_prop - accept += 1 - u_list.append(u) - v_list.append(v) - logl_list.append(logl) - else: - reject += 1 - # Only start appending to the chain once a single jump is made - if accept > 0: - u_list.append(u_list[-1]) - v_list.append(v_list[-1]) - logl_list.append(logl_list[-1]) - - # If we've taken the minimum number of steps, calculate the ACT - if accept + reject > walks: - act = estimate_nmcmc( - accept_ratio=accept / (accept + reject + nfail), - old_act=old_act, maxmcmc=maxmcmc) - - # If we've taken too many likelihood evaluations then break - if accept + reject > maxmcmc: - logging.warning( - "Hit maximum number of walks {} with accept={}, reject={}, " - "and nfail={} try increasing maxmcmc" - .format(maxmcmc, accept, reject, nfail)) - break - - # If the act is finite, pick randomly from within the chain - if numpy.isfinite(act) and int(.5 * nact * act) < len(u_list): - idx = numpy.random.randint(int(.5 * nact * act), len(u_list)) - u = u_list[idx] - v = v_list[idx] - logl = logl_list[idx] - else: - logging.debug("Unable to find a new point using walk: " - "returning a random point") - u = numpy.random.uniform(size=n) - v = prior_transform(u) - logl = loglikelihood(v) - - blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale} - kwargs["old_act"] = act - - ncall = accept + reject - return u, v, logl, ncall, blob - - -def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None): - """Estimate autocorrelation length of chain using acceptance fraction - - Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapated from - CPNest: - - * https://github.com/johnveitch/cpnest/blob/master/cpnest/sampler.py - * https://github.com/farr/Ensemble.jl - - Parameters - ---------- - accept_ratio: float [0, 1] - Ratio of the number of accepted points to the total number of points - old_act: int - The ACT of the last iteration - maxmcmc: int - The maximum length of the MCMC chain to use - safety: int - A safety factor applied in the calculation - tau: int (optional) - The ACT, if given, otherwise estimated. - """ - if tau is None: - tau = maxmcmc / safety - - if accept_ratio == 0.0: - Nmcmc_exact = (1 + 1 / tau) * old_act - else: - Nmcmc_exact = ( - (1. - 1. / tau) * old_act + - (safety / tau) * (2. / accept_ratio - 1.) - ) - Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc)) - return max(safety, int(Nmcmc_exact)) - diff --git a/pycbc/results/.ipynb_checkpoints/scatter_histograms-checkpoint.py b/pycbc/results/.ipynb_checkpoints/scatter_histograms-checkpoint.py deleted file mode 100644 index e89bccfd08b..00000000000 --- a/pycbc/results/.ipynb_checkpoints/scatter_histograms-checkpoint.py +++ /dev/null @@ -1,867 +0,0 @@ -# Copyright (C) 2016 Miriam Cabero Mueller, Collin Capano -# -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. 
-# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. - - -# -# ============================================================================= -# -# Preamble -# -# ============================================================================= -# -""" -Module to generate figures with scatter plots and histograms. -""" - -import itertools -import sys - -import numpy - -import scipy.stats - -import matplotlib - -# Only if a backend is not already set ... This should really *not* be done -# here, but in the executables you should set matplotlib.use() -# This matches the check that matplotlib does internally, but this *may* be -# version dependenant. If this is a problem then remove this and control from -# the executables directly. -if 'matplotlib.backends' not in sys.modules: # nopep8 - matplotlib.use('agg') - -from matplotlib import (offsetbox, pyplot, gridspec) - -from pycbc.results import str_utils -from pycbc.io import FieldArray - - -def create_axes_grid(parameters, labels=None, height_ratios=None, - width_ratios=None, no_diagonals=False): - """Given a list of parameters, creates a figure with an axis for - every possible combination of the parameters. - - Parameters - ---------- - parameters : list - Names of the variables to be plotted. - labels : {None, dict}, optional - A dictionary of parameters -> parameter labels. - height_ratios : {None, list}, optional - Set the height ratios of the axes; see `matplotlib.gridspec.GridSpec` - for details. - width_ratios : {None, list}, optional - Set the width ratios of the axes; see `matplotlib.gridspec.GridSpec` - for details. - no_diagonals : {False, bool}, optional - Do not produce axes for the same parameter on both axes. - - Returns - ------- - fig : pyplot.figure - The figure that was created. - axis_dict : dict - A dictionary mapping the parameter combinations to the axis and their - location in the subplots grid; i.e., the key, values are: - `{('param1', 'param2'): (pyplot.axes, row index, column index)}` - """ - if labels is None: - labels = {p: p for p in parameters} - elif any(p not in labels for p in parameters): - raise ValueError("labels must be provided for all parameters") - # Create figure with adequate size for number of parameters. - ndim = len(parameters) - if no_diagonals: - ndim -= 1 - if ndim < 3: - fsize = (8, 7) - else: - fsize = (ndim*3 - 1, ndim*3 - 2) - fig = pyplot.figure(figsize=fsize) - # create the axis grid - gs = gridspec.GridSpec(ndim, ndim, width_ratios=width_ratios, - height_ratios=height_ratios, - wspace=0.05, hspace=0.05) - # create grid of axis numbers to easily create axes in the right locations - axes = numpy.arange(ndim**2).reshape((ndim, ndim)) - - # Select possible combinations of plots and establish rows and columns. 
- combos = list(itertools.combinations(parameters, 2)) - # add the diagonals - if not no_diagonals: - combos += [(p, p) for p in parameters] - - # create the mapping between parameter combos and axes - axis_dict = {} - # cycle over all the axes, setting thing as needed - for nrow in range(ndim): - for ncolumn in range(ndim): - ax = pyplot.subplot(gs[axes[nrow, ncolumn]]) - # map to a parameter index - px = parameters[ncolumn] - if no_diagonals: - py = parameters[nrow+1] - else: - py = parameters[nrow] - if (px, py) in combos: - axis_dict[px, py] = (ax, nrow, ncolumn) - # x labels only on bottom - if nrow + 1 == ndim: - ax.set_xlabel('{}'.format(labels[px]), fontsize=18) - else: - pyplot.setp(ax.get_xticklabels(), visible=False) - # y labels only on left - if ncolumn == 0: - ax.set_ylabel('{}'.format(labels[py]), fontsize=18) - else: - pyplot.setp(ax.get_yticklabels(), visible=False) - else: - # make non-used axes invisible - ax.axis('off') - return fig, axis_dict - - -def get_scale_fac(fig, fiducial_width=8, fiducial_height=7): - """Gets a factor to scale fonts by for the given figure. The scale - factor is relative to a figure with dimensions - (`fiducial_width`, `fiducial_height`). - """ - width, height = fig.get_size_inches() - return (width*height/(fiducial_width*fiducial_height))**0.5 - - -def construct_kde(samples_array, use_kombine=False, kdeargs=None): - """Constructs a KDE from the given samples. - - Parameters - ---------- - samples_array : array - Array of values to construct the KDE for. - use_kombine : bool, optional - Use kombine's clustered KDE instead of scipy's. Default is False. - kdeargs : dict, optional - Additional arguments to pass to the KDE. Can be any argument recognized - by :py:func:`scipy.stats.gaussian_kde` or - :py:func:`kombine.clustered_kde.optimized_kde`. In either case, you can - also set ``max_kde_samples`` to limit the number of samples that are - used for KDE construction. - - Returns - ------- - kde : - The KDE. - """ - # make sure samples are randomly sorted - numpy.random.seed(0) - numpy.random.shuffle(samples_array) - # if kde arg specifies a maximum number of samples, limit them - if kdeargs is None: - kdeargs = {} - else: - kdeargs = kdeargs.copy() - max_nsamples = kdeargs.pop('max_kde_samples', None) - samples_array = samples_array[:max_nsamples] - if use_kombine: - try: - import kombine - except ImportError: - raise ImportError("kombine is not installed.") - if kdeargs is None: - kdeargs = {} - # construct the kde - if use_kombine: - kde = kombine.clustered_kde.optimized_kde(samples_array, **kdeargs) - else: - kde = scipy.stats.gaussian_kde(samples_array.T, **kdeargs) - return kde - - -def create_density_plot(xparam, yparam, samples, plot_density=True, - plot_contours=True, percentiles=None, cmap='viridis', - contour_color=None, label_contours=True, - contour_linestyles=None, - xmin=None, xmax=None, - ymin=None, ymax=None, exclude_region=None, - fig=None, ax=None, use_kombine=False, - kdeargs=None): - """Computes and plots posterior density and confidence intervals using the - given samples. - - Parameters - ---------- - xparam : string - The parameter to plot on the x-axis. - yparam : string - The parameter to plot on the y-axis. - samples : dict, numpy structured array, or FieldArray - The samples to plot. - plot_density : {True, bool} - Plot a color map of the density. - plot_contours : {True, bool} - Plot contours showing the n-th percentiles of the density. - percentiles : {None, float or array} - What percentile contours to draw. 
If None, will plot the 50th - and 90th percentiles. - cmap : {'viridis', string} - The name of the colormap to use for the density plot. - contour_color : {None, string} - What color to make the contours. Default is white for density - plots and black for other plots. - label_contours : bool, optional - Whether to label the contours. Default is True. - contour_linestyles : list, optional - Linestyles to use for the contours. Default (None) will use solid. - xmin : {None, float} - Minimum value to plot on x-axis. - xmax : {None, float} - Maximum value to plot on x-axis. - ymin : {None, float} - Minimum value to plot on y-axis. - ymax : {None, float} - Maximum value to plot on y-axis. - exclue_region : {None, str} - Exclude the specified region when plotting the density or contours. - Must be a string in terms of `xparam` and `yparam` that is - understandable by numpy's logical evaluation. For example, if - `xparam = m_1` and `yparam = m_2`, and you want to exclude the region - for which `m_2` is greater than `m_1`, then exclude region should be - `'m_2 > m_1'`. - fig : {None, pyplot.figure} - Add the plot to the given figure. If None and ax is None, will create - a new figure. - ax : {None, pyplot.axes} - Draw plot on the given axis. If None, will create a new axis from - `fig`. - use_kombine : {False, bool} - Use kombine's KDE to calculate density. Otherwise, will use - `scipy.stats.gaussian_kde.` Default is False. - kdeargs : dict, optional - Pass the given keyword arguments to the KDE. - - Returns - ------- - fig : pyplot.figure - The figure the plot was made on. - ax : pyplot.axes - The axes the plot was drawn on. - """ - if percentiles is None: - percentiles = numpy.array([50., 90.]) - percentiles = 100. - numpy.array(percentiles) - percentiles.sort() - - if ax is None and fig is None: - fig = pyplot.figure() - if ax is None: - ax = fig.add_subplot(111) - - # convert samples to array and construct kde - xsamples = samples[xparam] - ysamples = samples[yparam] - arr = numpy.vstack((xsamples, ysamples)).T - kde = construct_kde(arr, use_kombine=use_kombine, kdeargs=kdeargs) - - # construct grid to evaluate on - if xmin is None: - xmin = xsamples.min() - if xmax is None: - xmax = xsamples.max() - if ymin is None: - ymin = ysamples.min() - if ymax is None: - ymax = ysamples.max() - npts = 100 - X, Y = numpy.mgrid[ - xmin:xmax:complex(0, npts), # pylint:disable=invalid-slice-index - ymin:ymax:complex(0, npts)] # pylint:disable=invalid-slice-index - pos = numpy.vstack([X.ravel(), Y.ravel()]) - if use_kombine: - Z = numpy.exp(kde(pos.T).reshape(X.shape)) - draw = kde.draw - else: - Z = kde(pos).T.reshape(X.shape) - draw = kde.resample - - if exclude_region is not None: - # convert X,Y to a single FieldArray so we can use it's ability to - # evaluate strings - farr = FieldArray.from_kwargs(**{xparam: X, yparam: Y}) - Z[farr[exclude_region]] = 0. 
- - if plot_density: - ax.imshow(numpy.rot90(Z), extent=[xmin, xmax, ymin, ymax], - aspect='auto', cmap=cmap, zorder=1) - if contour_color is None: - contour_color = 'w' - - if plot_contours: - # compute the percentile values - resamps = kde(draw(int(npts**2))) - if use_kombine: - resamps = numpy.exp(resamps) - s = numpy.percentile(resamps, percentiles) - if contour_color is None: - contour_color = 'k' - # make linewidths thicker if not plotting density for clarity - if plot_density: - lw = 1 - else: - lw = 2 - ct = ax.contour(X, Y, Z, s, colors=contour_color, linewidths=lw, - linestyles=contour_linestyles, zorder=3) - # label contours - if label_contours: - lbls = ['{p}%'.format(p=int(p)) for p in (100. - percentiles)] - fmt = dict(zip(ct.levels, lbls)) - fs = 12 - ax.clabel(ct, ct.levels, inline=True, fmt=fmt, fontsize=fs) - - return fig, ax - - -def create_marginalized_hist(ax, values, label, percentiles=None, - color='k', fillcolor='gray', linecolor='navy', - linestyle='-', plot_marginal_lines=True, - title=True, expected_value=None, - expected_color='red', rotated=False, - plot_min=None, plot_max=None): - """Plots a 1D marginalized histogram of the given param from the given - samples. - - Parameters - ---------- - ax : pyplot.Axes - The axes on which to draw the plot. - values : array - The parameter values to plot. - label : str - A label to use for the title. - percentiles : {None, float or array} - What percentiles to draw lines at. If None, will draw lines at - `[5, 50, 95]` (i.e., the bounds on the upper 90th percentile and the - median). - color : {'k', string} - What color to make the histogram; default is black. - fillcolor : {'gray', string, or None} - What color to fill the histogram with. Set to None to not fill the - histogram. Default is 'gray'. - plot_marginal_lines : bool, optional - Put vertical lines at the marginal percentiles. Default is True. - linestyle : str, optional - What line style to use for the histogram. Default is '-'. - linecolor : {'navy', string} - What color to use for the percentile lines. Default is 'navy'. - title : bool, optional - Add a title with a estimated value +/- uncertainty. The estimated value - is the pecentile halfway between the max/min of ``percentiles``, while - the uncertainty is given by the max/min of the ``percentiles``. If no - percentiles are specified, defaults to quoting the median +/- 95/5 - percentiles. - rotated : {False, bool} - Plot the histogram on the y-axis instead of the x. Default is False. - plot_min : {None, float} - The minimum value to plot. If None, will default to whatever `pyplot` - creates. - plot_max : {None, float} - The maximum value to plot. If None, will default to whatever `pyplot` - creates. - scalefac : {1., float} - Factor to scale the default font sizes by. Default is 1 (no scaling). - """ - if fillcolor is None: - htype = 'step' - else: - htype = 'stepfilled' - if rotated: - orientation = 'horizontal' - else: - orientation = 'vertical' - ax.hist(values, bins=50, histtype=htype, orientation=orientation, - facecolor=fillcolor, edgecolor=color, ls=linestyle, lw=2, - density=True) - if percentiles is None: - percentiles = [5., 50., 95.] 
- if len(percentiles) > 0: - plotp = numpy.percentile(values, percentiles) - else: - plotp = [] - if plot_marginal_lines: - for val in plotp: - if rotated: - ax.axhline(y=val, ls='dashed', color=linecolor, lw=2, zorder=3) - else: - ax.axvline(x=val, ls='dashed', color=linecolor, lw=2, zorder=3) - # plot expected - if expected_value is not None: - if rotated: - ax.axhline(expected_value, color=expected_color, lw=1.5, zorder=2) - else: - ax.axvline(expected_value, color=expected_color, lw=1.5, zorder=2) - if title: - if len(percentiles) > 0: - minp = min(percentiles) - maxp = max(percentiles) - medp = (maxp + minp) / 2. - else: - minp = 5 - medp = 50 - maxp = 95 - values_min = numpy.percentile(values, minp) - values_med = numpy.percentile(values, medp) - values_max = numpy.percentile(values, maxp) - negerror = values_med - values_min - poserror = values_max - values_med - fmt = '${0}$'.format(str_utils.format_value( - values_med, negerror, plus_error=poserror)) - if rotated: - ax.yaxis.set_label_position("right") - # sets colored title for marginal histogram - set_marginal_histogram_title(ax, fmt, color, - label=label, rotated=rotated) - else: - # sets colored title for marginal histogram - set_marginal_histogram_title(ax, fmt, color, label=label) - # remove ticks and set limits - if rotated: - # Remove x-ticks - ax.set_xticks([]) - # turn off x-labels - ax.set_xlabel('') - # set limits - ymin, ymax = ax.get_ylim() - if plot_min is not None: - ymin = plot_min - if plot_max is not None: - ymax = plot_max - ax.set_ylim(ymin, ymax) - else: - # Remove y-ticks - ax.set_yticks([]) - # turn off y-label - ax.set_ylabel('') - # set limits - xmin, xmax = ax.get_xlim() - if plot_min is not None: - xmin = plot_min - if plot_max is not None: - xmax = plot_max - ax.set_xlim(xmin, xmax) - - -def set_marginal_histogram_title(ax, fmt, color, label=None, rotated=False): - """ Sets the title of the marginal histograms. - - Parameters - ---------- - ax : Axes - The `Axes` instance for the plot. - fmt : str - The string to add to the title. - color : str - The color of the text to add to the title. - label : str - If title does not exist, then include label at beginning of the string. - rotated : bool - If `True` then rotate the text 270 degrees for sideways title. 
- """ - - # get rotation angle of the title - rotation = 270 if rotated else 0 - - # get how much to displace title on axes - xscale = 1.05 if rotated else 0.0 - if rotated: - yscale = 1.0 - elif len(ax.get_figure().axes) > 1: - yscale = 1.15 - else: - yscale = 1.05 - - # get class that packs text boxes vertical or horizonitally - packer_class = offsetbox.VPacker if rotated else offsetbox.HPacker - - # if no title exists - if not hasattr(ax, "title_boxes"): - - # create a text box - title = "{} = {}".format(label, fmt) - tbox1 = offsetbox.TextArea( - title, - textprops=dict(color=color, size=15, rotation=rotation, - ha='left', va='bottom')) - - # save a list of text boxes as attribute for later - ax.title_boxes = [tbox1] - - # pack text boxes - ybox = packer_class(children=ax.title_boxes, - align="bottom", pad=0, sep=5) - - # else append existing title - else: - - # delete old title - ax.title_anchor.remove() - - # add new text box to list - tbox1 = offsetbox.TextArea( - " {}".format(fmt), - textprops=dict(color=color, size=15, rotation=rotation, - ha='left', va='bottom')) - ax.title_boxes = ax.title_boxes + [tbox1] - - # pack text boxes - ybox = packer_class(children=ax.title_boxes, - align="bottom", pad=0, sep=5) - - # add new title and keep reference to instance as an attribute - anchored_ybox = offsetbox.AnchoredOffsetbox( - loc=2, child=ybox, pad=0., - frameon=False, bbox_to_anchor=(xscale, yscale), - bbox_transform=ax.transAxes, borderpad=0.) - ax.title_anchor = ax.add_artist(anchored_ybox) - - -def create_multidim_plot(parameters, samples, labels=None, - mins=None, maxs=None, expected_parameters=None, - expected_parameters_color='r', - plot_marginal=True, plot_scatter=True, - plot_maxl=False, - plot_marginal_lines=True, - marginal_percentiles=None, contour_percentiles=None, - marginal_title=True, marginal_linestyle='-', - zvals=None, show_colorbar=True, cbar_label=None, - vmin=None, vmax=None, scatter_cmap='plasma', - plot_density=False, plot_contours=True, - density_cmap='viridis', - contour_color=None, label_contours=True, - contour_linestyles=None, - hist_color='black', - line_color=None, fill_color='gray', - use_kombine=False, kdeargs=None, - fig=None, axis_dict=None): - """Generate a figure with several plots and histograms. - - Parameters - ---------- - parameters: list - Names of the variables to be plotted. - samples : FieldArray - A field array of the samples to plot. - labels: dict, optional - A dictionary mapping parameters to labels. If none provided, will just - use the parameter strings as the labels. - mins : {None, dict}, optional - Minimum value for the axis of each variable in `parameters`. - If None, it will use the minimum of the corresponding variable in - `samples`. - maxs : {None, dict}, optional - Maximum value for the axis of each variable in `parameters`. - If None, it will use the maximum of the corresponding variable in - `samples`. - expected_parameters : {None, dict}, optional - Expected values of `parameters`, as a dictionary mapping parameter - names -> values. A cross will be plotted at the location of the - expected parameters on axes that plot any of the expected parameters. - expected_parameters_color : {'r', string}, optional - What color to make the expected parameters cross. - plot_marginal : {True, bool} - Plot the marginalized distribution on the diagonals. If False, the - diagonal axes will be turned off. - plot_scatter : {True, bool} - Plot each sample point as a scatter plot. 
- marginal_percentiles : {None, array} - What percentiles to draw lines at on the 1D histograms. - If None, will draw lines at `[5, 50, 95]` (i.e., the bounds on the - upper 90th percentile and the median). - marginal_title : bool, optional - Add a title over the 1D marginal plots that gives an estimated value - +/- uncertainty. The estimated value is the pecentile halfway between - the max/min of ``maginal_percentiles``, while the uncertainty is given - by the max/min of the ``marginal_percentiles. If no - ``marginal_percentiles`` are specified, the median +/- 95/5 percentiles - will be quoted. - marginal_linestyle : str, optional - What line style to use for the marginal histograms. - contour_percentiles : {None, array} - What percentile contours to draw on the scatter plots. If None, - will plot the 50th and 90th percentiles. - zvals : {None, array} - An array to use for coloring the scatter plots. If None, scatter points - will be the same color. - show_colorbar : {True, bool} - Show the colorbar of zvalues used for the scatter points. A ValueError - will be raised if zvals is None and this is True. - cbar_label : {None, str} - Specify a label to add to the colorbar. - vmin: {None, float}, optional - Minimum value for the colorbar. If None, will use the minimum of zvals. - vmax: {None, float}, optional - Maximum value for the colorbar. If None, will use the maxmimum of - zvals. - scatter_cmap : {'plasma', string} - The color map to use for the scatter points. Default is 'plasma'. - plot_density : {False, bool} - Plot the density of points as a color map. - plot_contours : {True, bool} - Draw contours showing the 50th and 90th percentile confidence regions. - density_cmap : {'viridis', string} - The color map to use for the density plot. - contour_color : {None, string} - The color to use for the contour lines. Defaults to white for - density plots, navy for scatter plots without zvals, and black - otherwise. - label_contours : bool, optional - Whether to label the contours. Default is True. - contour_linestyles : list, optional - Linestyles to use for the contours. Default (None) will use solid. - use_kombine : {False, bool} - Use kombine's KDE to calculate density. Otherwise, will use - `scipy.stats.gaussian_kde.` Default is False. - kdeargs : dict, optional - Pass the given keyword arguments to the KDE. - fig : pyplot.figure - Use the given figure instead of creating one. - axis_dict : dict - Use the given dictionary of axes instead of creating one. - - Returns - ------- - fig : pyplot.figure - The figure that was created. 
- axis_dict : dict - A dictionary mapping the parameter combinations to the axis and their - location in the subplots grid; i.e., the key, values are: - `{('param1', 'param2'): (pyplot.axes, row index, column index)}` - """ - if labels is None: - labels = {p: p for p in parameters} - # set up the figure with a grid of axes - # if only plotting 2 parameters, make the marginal plots smaller - nparams = len(parameters) - if nparams == 2: - width_ratios = [3, 1] - height_ratios = [1, 3] - else: - width_ratios = height_ratios = None - - if plot_maxl: - # make sure loglikelihood is provide - if 'loglikelihood' not in samples.fieldnames: - raise ValueError("plot-maxl requires loglikelihood") - maxidx = samples['loglikelihood'].argmax() - - # only plot scatter if more than one parameter - plot_scatter = plot_scatter and nparams > 1 - - # Sort zvals to get higher values on top in scatter plots - if plot_scatter: - if zvals is not None: - sort_indices = zvals.argsort() - zvals = zvals[sort_indices] - samples = samples[sort_indices] - if contour_color is None: - contour_color = 'k' - elif show_colorbar: - raise ValueError("must provide z values to create a colorbar") - else: - # just make all scatter points same color - zvals = 'gray' - if plot_contours and contour_color is None: - contour_color = 'navy' - - # create the axis grid - if fig is None and axis_dict is None: - fig, axis_dict = create_axes_grid( - parameters, labels=labels, - width_ratios=width_ratios, height_ratios=height_ratios, - no_diagonals=not plot_marginal) - - # convert samples to a dictionary to avoid re-computing derived parameters - # every time they are needed - # only try to plot what's available - sd = {} - for p in parameters: - try: - sd[p] = samples[p] - except (ValueError, TypeError, IndexError): - continue - samples = sd - parameters = list(sd.keys()) - - # values for axis bounds - if mins is None: - mins = {p: samples[p].min() for p in parameters} - else: - # copy the dict - mins = {p: val for p, val in mins.items()} - if maxs is None: - maxs = {p: samples[p].max() for p in parameters} - else: - # copy the dict - maxs = {p: val for p, val in maxs.items()} - - # Diagonals... - if plot_marginal: - for pi, param in enumerate(parameters): - ax, _, _ = axis_dict[param, param] - # if only plotting 2 parameters and on the second parameter, - # rotate the marginal plot - rotated = nparams == 2 and pi == nparams-1 - # see if there are expected values - if expected_parameters is not None: - try: - expected_value = expected_parameters[param] - except KeyError: - expected_value = None - else: - expected_value = None - create_marginalized_hist( - ax, samples[param], label=labels[param], - color=hist_color, fillcolor=fill_color, - plot_marginal_lines=plot_marginal_lines, - linestyle=marginal_linestyle, linecolor=line_color, - title=marginal_title, expected_value=expected_value, - expected_color=expected_parameters_color, - rotated=rotated, plot_min=mins[param], plot_max=maxs[param], - percentiles=marginal_percentiles) - - # Off-diagonals... - for px, py in axis_dict: - if px == py or px not in parameters or py not in parameters: - continue - ax, _, _ = axis_dict[px, py] - if plot_scatter: - if plot_density: - alpha = 0.3 - else: - alpha = 1. 
- plt = ax.scatter(x=samples[px], y=samples[py], c=zvals, s=5, - edgecolors='none', vmin=vmin, vmax=vmax, - cmap=scatter_cmap, alpha=alpha, zorder=2) - - if plot_contours or plot_density: - # Exclude out-of-bound regions - # this is a bit kludgy; should probably figure out a better - # solution to eventually allow for more than just m_p m_s - if (px == 'm_p' and py == 'm_s') or (py == 'm_p' and px == 'm_s'): - exclude_region = 'm_s > m_p' - else: - exclude_region = None - create_density_plot( - px, py, samples, plot_density=plot_density, - plot_contours=plot_contours, cmap=density_cmap, - percentiles=contour_percentiles, - contour_color=contour_color, label_contours=label_contours, - contour_linestyles=contour_linestyles, - xmin=mins[px], xmax=maxs[px], - ymin=mins[py], ymax=maxs[py], - exclude_region=exclude_region, ax=ax, - use_kombine=use_kombine, kdeargs=kdeargs) - - if plot_maxl: - maxlx = samples[px][maxidx] - maxly = samples[py][maxidx] - ax.scatter(maxlx, maxly, marker='x', s=20, c=contour_color, - zorder=5) - - if expected_parameters is not None: - try: - ax.axvline(expected_parameters[px], lw=1.5, - color=expected_parameters_color, zorder=5) - except KeyError: - pass - try: - ax.axhline(expected_parameters[py], lw=1.5, - color=expected_parameters_color, zorder=5) - except KeyError: - pass - - ax.set_xlim(mins[px], maxs[px]) - ax.set_ylim(mins[py], maxs[py]) - - # adjust tick number for large number of plots - if len(parameters) > 3: - for px, py in axis_dict: - ax, _, _ = axis_dict[px, py] - ax.set_xticks(reduce_ticks(ax, 'x', maxticks=3)) - ax.set_yticks(reduce_ticks(ax, 'y', maxticks=3)) - - if plot_scatter and show_colorbar: - # compute font size based on fig size - scale_fac = get_scale_fac(fig) - fig.subplots_adjust(right=0.85, wspace=0.03) - cbar_ax = fig.add_axes([0.9, 0.1, 0.03, 0.8]) - cb = fig.colorbar(plt, cax=cbar_ax) - if cbar_label is not None: - cb.set_label(cbar_label, fontsize=12*scale_fac) - cb.ax.tick_params(labelsize=8*scale_fac) - - return fig, axis_dict - - -def remove_common_offset(arr): - """Given an array of data, removes a common offset > 1000, returning the - removed value. - """ - offset = 0 - isneg = (arr <= 0).all() - # make sure all values have the same sign - if isneg or (arr >= 0).all(): - # only remove offset if the minimum and maximum values are the same - # order of magintude and > O(1000) - minpwr = numpy.log10(abs(arr).min()) - maxpwr = numpy.log10(abs(arr).max()) - if numpy.floor(minpwr) == numpy.floor(maxpwr) and minpwr > 3: - offset = numpy.floor(10**minpwr) - if isneg: - offset *= -1 - arr = arr - offset - return arr, int(offset) - - -def reduce_ticks(ax, which, maxticks=3): - """Given a pyplot axis, resamples its `which`-axis ticks such that are at most - `maxticks` left. - - Parameters - ---------- - ax : axis - The axis to adjust. - which : {'x' | 'y'} - Which axis to adjust. - maxticks : {3, int} - Maximum number of ticks to use. - - Returns - ------- - array - An array of the selected ticks. - """ - ticks = getattr(ax, 'get_{}ticks'.format(which))() - if len(ticks) > maxticks: - # make sure the left/right value is not at the edge - minax, maxax = getattr(ax, 'get_{}lim'.format(which))() - dw = abs(maxax-minax)/10. 
- start_idx, end_idx = 0, len(ticks) - if ticks[0] < minax + dw: - start_idx += 1 - if ticks[-1] > maxax - dw: - end_idx -= 1 - # get reduction factor - fac = int(len(ticks) / maxticks) - ticks = ticks[start_idx:end_idx:fac] - return ticks diff --git a/pycbc/strain/.ipynb_checkpoints/gate-checkpoint.py b/pycbc/strain/.ipynb_checkpoints/gate-checkpoint.py deleted file mode 100644 index e294beeae87..00000000000 --- a/pycbc/strain/.ipynb_checkpoints/gate-checkpoint.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (C) 2016 Collin Capano -# -# This program is free software; you can redistribute it and/or modify it -# under the terms of the GNU General Public License as published by the -# Free Software Foundation; either version 3 of the License, or (at your -# option) any later version. -# -# This program is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General -# Public License for more details. -# -# You should have received a copy of the GNU General Public License along -# with this program; if not, write to the Free Software Foundation, Inc., -# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -""" Functions for applying gates to data. -""" - -from scipy import linalg -from . import strain - - -def _gates_from_cli(opts, gate_opt): - """Parses the given `gate_opt` into something understandable by - `strain.gate_data`. - """ - gates = {} - if getattr(opts, gate_opt) is None: - return gates - for gate in getattr(opts, gate_opt): - try: - ifo, central_time, half_dur, taper_dur = gate.split(':') - central_time = float(central_time) - half_dur = float(half_dur) - taper_dur = float(taper_dur) - except ValueError: - raise ValueError("--gate {} not formatted correctly; ".format( - gate) + "see help") - try: - gates[ifo].append((central_time, half_dur, taper_dur)) - except KeyError: - gates[ifo] = [(central_time, half_dur, taper_dur)] - return gates - - -def gates_from_cli(opts): - """Parses the --gate option into something understandable by - `strain.gate_data`. - """ - return _gates_from_cli(opts, 'gate') - - -def psd_gates_from_cli(opts): - """Parses the --psd-gate option into something understandable by - `strain.gate_data`. - """ - return _gates_from_cli(opts, 'psd_gate') - - -def apply_gates_to_td(strain_dict, gates): - """Applies the given dictionary of gates to the given dictionary of - strain. - - Parameters - ---------- - strain_dict : dict - Dictionary of time-domain strain, keyed by the ifos. - gates : dict - Dictionary of gates. Keys should be the ifo to apply the data to, - values are a tuple giving the central time of the gate, the half - duration, and the taper duration. - - Returns - ------- - dict - Dictionary of time-domain strain with the gates applied. - """ - # copy data to new dictionary - outdict = dict(strain_dict.items()) - for ifo in gates: - outdict[ifo] = strain.gate_data(outdict[ifo], gates[ifo]) - return outdict - - -def apply_gates_to_fd(stilde_dict, gates): - """Applies the given dictionary of gates to the given dictionary of - strain in the frequency domain. - - Gates are applied by IFFT-ing the strain data to the time domain, applying - the gate, then FFT-ing back to the frequency domain. - - Parameters - ---------- - stilde_dict : dict - Dictionary of frequency-domain strain, keyed by the ifos. - gates : dict - Dictionary of gates. 
Keys should be the ifo to apply the data to, - values are a tuple giving the central time of the gate, the half - duration, and the taper duration. - - Returns - ------- - dict - Dictionary of frequency-domain strain with the gates applied. - """ - # copy data to new dictionary - outdict = dict(stilde_dict.items()) - # create a time-domin strain dictionary to apply the gates to - strain_dict = dict([[ifo, outdict[ifo].to_timeseries()] for ifo in gates]) - # apply gates and fft back to the frequency domain - for ifo,d in apply_gates_to_td(strain_dict, gates).items(): - outdict[ifo] = d.to_frequencyseries() - return outdict - - -def add_gate_option_group(parser): - """Adds the options needed to apply gates to data. - - Parameters - ---------- - parser : object - ArgumentParser instance. - """ - gate_group = parser.add_argument_group("Options for gating data") - - gate_group.add_argument("--gate", nargs="+", type=str, - metavar="IFO:CENTRALTIME:HALFDUR:TAPERDUR", - help="Apply one or more gates to the data before " - "filtering.") - gate_group.add_argument("--gate-overwhitened", action="store_true", - help="Overwhiten data first, then apply the " - "gates specified in --gate. Overwhitening " - "allows for sharper tapers to be used, " - "since lines are not blurred.") - gate_group.add_argument("--psd-gate", nargs="+", type=str, - metavar="IFO:CENTRALTIME:HALFDUR:TAPERDUR", - help="Apply one or more gates to the data used " - "for computing the PSD. Gates are applied " - "prior to FFT-ing the data for PSD " - "estimation.") - return gate_group - - -def gate_and_paint(data, lindex, rindex, invpsd, copy=True): - """Gates and in-paints data. - - Parameters - ---------- - data : TimeSeries - The data to gate. - lindex : int - The start index of the gate. - rindex : int - The end index of the gate. - invpsd : FrequencySeries - The inverse of the PSD. - copy : bool, optional - Copy the data before applying the gate. Otherwise, the gate will - be applied in-place. Default is True. - - Returns - ------- - TimeSeries : - The gated and in-painted time series. - """ - # Uses the hole-filling method of - # https://arxiv.org/pdf/1908.05644.pdf - # Copy the data and zero inside the hole - if copy: - data = data.copy() - data[lindex:rindex] = 0 - - # get the over-whitened gated data - tdfilter = invpsd.astype('complex').to_timeseries() * invpsd.delta_t - owhgated_data = (data.to_frequencyseries() * invpsd).to_timeseries() - - # remove the projection into the null space - proj = linalg.solve_toeplitz(tdfilter[:(rindex - lindex)], - owhgated_data[lindex:rindex]) - data[lindex:rindex] -= proj - data.projslc = (lindex, rindex) - data.proj = proj - return data
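
The gate_and_paint helper removed above implements the hole-filling in-painting method of https://arxiv.org/pdf/1908.05644.pdf: the samples inside the gate are zeroed, the result is over-whitened with the inverse PSD, and the part of the over-whitened data that leaks into the gated span is removed by solving a Toeplitz system built from the time-domain filter of the inverse PSD. The sketch below restates those same steps with plain NumPy/SciPy arrays, independent of PyCBC's TimeSeries/FrequencySeries types; the name inpaint_gate and the bare numpy FFT conventions are illustrative assumptions, not part of the PyCBC API (any overall normalization of the inverse PSD cancels in the Toeplitz solve).

import numpy as np
from scipy import linalg


def inpaint_gate(data, lindex, rindex, invpsd):
    """Illustrative sketch of the hole-filling method used by gate_and_paint.

    data : 1-D real array of time-domain strain.
    invpsd : one-sided inverse PSD sampled at the rfft frequencies of
        ``data`` (length ``len(data) // 2 + 1``).
    """
    data = np.asarray(data, dtype=float).copy()
    data[lindex:rindex] = 0.0  # zero the data inside the gate

    n = len(data)
    # time-domain filter corresponding to the inverse PSD
    tdfilter = np.fft.irfft(invpsd, n=n)
    # over-whitened gated data
    owh = np.fft.irfft(np.fft.rfft(data) * invpsd, n=n)

    # remove the projection of the gated data into the null space of the gate
    proj = linalg.solve_toeplitz(tdfilter[:rindex - lindex],
                                 owh[lindex:rindex])
    data[lindex:rindex] -= proj
    return data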
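
The gate options consumed by _gates_from_cli and gates_from_cli above are colon-separated IFO:CENTRALTIME:HALFDUR:TAPERDUR strings. As a quick worked example (the detector name and numbers are hypothetical placeholders, not values taken from this patch), a single entry is parsed as follows:

# Hypothetical --gate entry; mirrors the split/float logic in _gates_from_cli
gate_str = "H1:1126259462.42:0.5:0.25"
ifo, central_time, half_dur, taper_dur = gate_str.split(':')
gates = {ifo: [(float(central_time), float(half_dur), float(taper_dur))]}
# gates == {'H1': [(1126259462.42, 0.5, 0.25)]}: a gate centered at GPS time
# 1126259462.42 with half-duration 0.5 s and taper duration 0.25 s, in the
# per-ifo list-of-tuples form passed on to strain.gate_data by
# apply_gates_to_td / apply_gates_to_fd.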