Standardize some common sampler functionality
ColmTalbot committed Aug 10, 2022
1 parent 907dbf6 commit d7df754
Showing 29 changed files with 2,801 additions and 2,389 deletions.
3 changes: 2 additions & 1 deletion .gitlab-ci.yml
@@ -156,9 +156,10 @@ python-3.9:
   stage: test
   script:
     - python -m pip install .
+    - python -m pip install schwimmbad
     - python -m pip list installed

-    - pytest test/integration/sampler_run_test.py --durations 10
+    - pytest test/integration/sampler_run_test.py --durations 10 -v

 python-3.8-samplers:
   <<: *test-sampler
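Note: schwimmbad is a third-party library that exposes serial, multiprocessing, and MPI pools behind a single interface; installing it lets the sampler-run tests exercise that pool path. A minimal, hypothetical usage sketch (illustrative only, not code from this commit):

```python
# Hedged sketch of the pool interface schwimmbad provides; the worker
# function and values here are made up for illustration.
from schwimmbad import choose_pool

def square(x):
    return x * x

if __name__ == "__main__":
    # mpi=False with processes=2 selects a multiprocessing-backed pool
    pool = choose_pool(mpi=False, processes=2)
    print(list(pool.map(square, range(8))))
    pool.close()
```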
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -9,7 +9,7 @@ repos:
     hooks:
       - id: black
         language_version: python3
-        files: '(^bilby/bilby_mcmc/|^examples/)'
+        files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)'
   - repo: https://github.com/codespell-project/codespell
     rev: v2.1.0
     hooks:
@@ -20,7 +20,7 @@ repos:
     hooks:
       - id: isort  # sort imports alphabetically and separate imports into sections
         args: [-w=88, -m=3, -tc, -sp=setup.cfg]
-        files: '(^bilby/bilby_mcmc/|^examples/)'
+        files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)'
   - repo: https://github.com/datarootsio/databooks
     rev: 0.1.14
     hooks:
139 changes: 30 additions & 109 deletions bilby/bilby_mcmc/sampler.py
@@ -1,14 +1,19 @@
 import datetime
 import os
-import signal
 import time
 from collections import Counter

 import numpy as np
 import pandas as pd

 from ..core.result import rejection_sample
-from ..core.sampler.base_sampler import MCMCSampler, ResumeError, SamplerError
+from ..core.sampler.base_sampler import (
+    MCMCSampler,
+    ResumeError,
+    SamplerError,
+    _sampling_convenience_dump,
+    signal_wrapper,
+)
 from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
 from . import proposals
 from .chain import Chain, Sample
@@ -131,7 +136,6 @@ class Bilby_MCMC(MCMCSampler):
         autocorr_c=5,
         L1steps=100,
         L2steps=3,
-        npool=1,
         printdt=60,
         min_tau=1,
         proposal_cycle="default",
@@ -172,7 +176,6 @@ def __init__(
         self.check_point_plot = check_point_plot
         self.diagnostic = diagnostic
         self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
-        self.npool = self.kwargs["npool"]
         self.L1steps = self.kwargs["L1steps"]
         self.L2steps = self.kwargs["L2steps"]
         self.pt_inputs = ParallelTemperingInputs(
@@ -194,17 +197,6 @@ def __init__(
         self.verify_configuration()
         self.verbose = verbose

-        try:
-            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-        except AttributeError:
-            logger.debug(
-                "Setting signal attributes unavailable on this system. "
-                "This is likely the case if you are running on a Windows machine"
-                " and is no further concern."
-            )
-
     def verify_configuration(self):
         if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
             logger.warning("Burn-in inefficiency fraction greater than 10%")
@@ -223,6 +215,7 @@ def _translate_kwargs(self, kwargs):
     def target_nsamples(self):
         return self.kwargs["target_nsamples"]

+    @signal_wrapper
     def run_sampler(self):
         self._setup_pool()
         self.setup_chain_set()
@@ -377,31 +370,12 @@ def read_current_state(self):
             f"setup:\n{self.get_setup_string()}"
         )

-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """
-        Make sure that if a pool of jobs is running only the parent tries to
-        checkpoint and exit. Only the parent has a 'pool' attribute.
-        """
-        if self.npool == 1 or getattr(self, "pool", None) is not None:
-            if signum == 14:
-                logger.info(
-                    "Run interrupted by alarm signal {}: checkpoint and exit on {}".format(
-                        signum, self.exit_code
-                    )
-                )
-            else:
-                logger.info(
-                    "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                        signum, self.exit_code
-                    )
-                )
-            self.write_current_state()
-            self._close_pool()
-            os._exit(self.exit_code)
-
     def write_current_state(self):
         import dill

+        if not hasattr(self, "ptsampler"):
+            logger.debug("Attempted checkpoint before initialization")
+            return
         logger.debug("Check point")
         check_directory_exists_and_if_not_mkdir(self.outdir)
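The handler deleted above is superseded by the shared signal_wrapper decorator now imported from base_sampler and applied to run_sampler. Its implementation is not shown in this diff; a plausible sketch of the pattern, assuming the wrapped sampler exposes a write_current_state_and_exit method:

```python
# Sketch only: the real decorator lives in bilby/core/sampler/base_sampler.py
# and may differ in detail.
import functools
import signal

def signal_wrapper(method):
    """Install checkpoint-and-exit handlers before running `method`."""

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        try:
            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
        except (AttributeError, ValueError):
            # Signals are unavailable on some platforms (e.g. Windows)
            # or outside the main thread; run without checkpoint handlers.
            pass
        return method(self, *args, **kwargs)

    return wrapped
```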

@@ -534,39 +508,6 @@ def plot_progress(ptsampler, label, outdir, priors, diagnostic=False):
             all_samples=ptsampler.samples,
         )

-    def _setup_pool(self):
-        if self.npool > 1:
-            logger.info(f"Setting up multiproccesing pool with {self.npool} processes")
-            import multiprocessing
-
-            self.pool = multiprocessing.Pool(
-                processes=self.npool,
-                initializer=_initialize_global_variables,
-                initargs=(
-                    self.likelihood,
-                    self.priors,
-                    self._search_parameter_keys,
-                    self.use_ratio,
-                ),
-            )
-        else:
-            self.pool = None
-
-        _initialize_global_variables(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            search_parameter_keys=self._search_parameter_keys,
-            use_ratio=self.use_ratio,
-        )
-
-    def _close_pool(self):
-        if getattr(self, "pool", None) is not None:
-            logger.info("Starting to close worker pool.")
-            self.pool.close()
-            self.pool.join()
-            self.pool = None
-            logger.info("Finished closing worker pool.")
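_setup_pool and _close_pool are removed here because the commit hoists pool handling into the shared sampler base class, so every sampler gets identical behaviour. A self-contained sketch of that lifecycle (simplified; the real base-class version also passes an initializer that fills _sampling_convenience_dump in each worker):

```python
import multiprocessing

class PooledSampler:
    """Minimal sketch of the pool lifecycle moved into the base class."""

    def __init__(self, npool=1):
        self.npool = npool
        self.pool = None

    def _setup_pool(self):
        # The real base class supplies initializer/initargs so each worker
        # populates _sampling_convenience_dump; omitted for brevity.
        if self.npool > 1:
            self.pool = multiprocessing.Pool(processes=self.npool)

    def _close_pool(self):
        if self.pool is not None:
            self.pool.close()
            self.pool.join()
            self.pool = None
```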


 class BilbyPTMCMCSampler(object):
     def __init__(
Expand All @@ -579,7 +520,6 @@ def __init__(
         use_ratio,
         evidence_method,
     ):
-
         self.set_pt_inputs(pt_inputs)
         self.use_ratio = use_ratio
         self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
@@ -597,7 +537,7 @@ def __init__(

         self._nsamples_dict = {}
         self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
-            _priors
+            _sampling_convenience_dump.priors
         )
         self.sampling_time = 0
         self.ln_z_dict = dict()
@@ -612,7 +552,7 @@ def get_initial_betas(self):
         elif pt_inputs.Tmax is not None:
             betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
         elif pt_inputs.Tmax_from_SNR is not None:
-            ndim = len(_priors.non_fixed_keys)
+            ndim = len(_sampling_convenience_dump.priors.non_fixed_keys)
             target_hot_likelihood = ndim / 2
             Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood)
             betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
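For intuition: the Tmax_from_SNR branch sets the hottest temperature so the hottest chain sees a log-likelihood of about ndim / 2, which reduces to Tmax = SNR^2 / ndim, then spaces the inverse temperatures geometrically. A worked numerical sketch (values hypothetical):

```python
import numpy as np

snr, ndim, ntemps = 8.0, 4, 5                  # hypothetical inputs
target_hot_likelihood = ndim / 2               # 2.0
tmax = snr**2 / (2 * target_hot_likelihood)    # 64 / 4 = 16.0
betas = np.logspace(0, -np.log10(tmax), ntemps)
print(betas)  # [1.0, 0.5, 0.25, 0.125, 0.0625]; beta=1 cold, beta=1/16 hot
```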
@@ -1140,12 +1080,14 @@ def __init__(
         self.Eindex = Eindex
         self.use_ratio = use_ratio

-        self.parameters = _priors.non_fixed_keys
+        self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
         self.ndim = len(self.parameters)

-        full_sample_dict = _priors.sample()
+        full_sample_dict = _sampling_convenience_dump.priors.sample()
         initial_sample = {
-            k: v for k, v in full_sample_dict.items() if k in _priors.non_fixed_keys
+            k: v
+            for k, v in full_sample_dict.items()
+            if k in _sampling_convenience_dump.priors.non_fixed_keys
         }
         initial_sample = Sample(initial_sample)
         initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
@@ -1168,7 +1110,10 @@ def __init__(
                 warn = False

             self.proposal_cycle = proposals.get_proposal_cycle(
-                proposal_cycle, _priors, L1steps=self.chain.L1steps, warn=warn
+                proposal_cycle,
+                _sampling_convenience_dump.priors,
+                L1steps=self.chain.L1steps,
+                warn=warn,
             )
         elif isinstance(proposal_cycle, proposals.ProposalCycle):
             self.proposal_cycle = proposal_cycle
@@ -1185,17 +1130,17 @@ def set_convergence_inputs(self, convergence_inputs):
         self.stop_after_convergence = convergence_inputs.stop_after_convergence

     def log_likelihood(self, sample):
-        _likelihood.parameters.update(sample.sample_dict)
+        _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict)

         if self.use_ratio:
-            logl = _likelihood.log_likelihood_ratio()
+            logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio()
         else:
-            logl = _likelihood.log_likelihood()
+            logl = _sampling_convenience_dump.likelihood.log_likelihood()

         return logl

     def log_prior(self, sample):
-        return _priors.ln_prob(sample.parameter_only_dict)
+        return _sampling_convenience_dump.priors.ln_prob(sample.parameter_only_dict)

     def accept_proposal(self, prop, proposal):
         self.chain.append(prop)
@@ -1293,8 +1238,10 @@ def rejection_sample_zero_temperature_samples(self, print_message=False):
         zerotemp_logl = hot_samples[LOGLKEY]

         # Revert to true likelihood if needed
-        if _use_ratio:
-            zerotemp_logl += _likelihood.noise_log_likelihood()
+        if _sampling_convenience_dump.use_ratio:
+            zerotemp_logl += (
+                _sampling_convenience_dump.likelihood.noise_log_likelihood()
+            )

         # Calculate normalised weights
         log_weights = (1 - beta) * zerotemp_logl
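The weighting follows from tempering: a chain at inverse temperature beta samples from L^beta times the prior, so reweighting its draws to the beta = 1 posterior needs weights proportional to L^(1 - beta), i.e. log-weights of (1 - beta) * log L. A standalone sketch of the subsequent rejection step (stand-in data, not bilby code):

```python
import numpy as np

rng = np.random.default_rng(0)
beta = 0.5                                     # hypothetical hot-chain beta
zerotemp_logl = rng.normal(-10.0, 2.0, 1000)   # stand-in log-likelihoods

log_weights = (1 - beta) * zerotemp_logl
log_weights -= log_weights.max()               # normalise: max weight -> 1
keep = rng.uniform(size=log_weights.size) < np.exp(log_weights)
print(f"kept {keep.sum()} of {keep.size} hot samples")
```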
@@ -1322,29 +1269,3 @@ def rejection_sample_zero_temperature_samples(self, print_message=False):
 def call_step(sampler):
     sampler = sampler.step()
     return sampler
-
-
-_likelihood = None
-_priors = None
-_search_parameter_keys = None
-_use_ratio = False
-
-
-def _initialize_global_variables(
-    likelihood,
-    priors,
-    search_parameter_keys,
-    use_ratio,
-):
-    """
-    Store a global copy of the likelihood, priors, and search keys for
-    multiprocessing.
-    """
-    global _likelihood
-    global _priors
-    global _search_parameter_keys
-    global _use_ratio
-    _likelihood = likelihood
-    _priors = priors
-    _search_parameter_keys = search_parameter_keys
-    _use_ratio = use_ratio
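The four module-level globals deleted above are replaced by the _sampling_convenience_dump object imported from base_sampler at the top of the file. Its definition is not part of this diff; a plausible minimal shape is a namespace that the pool initializer fills once per worker process (sketch, not the actual bilby code):

```python
from types import SimpleNamespace

# One shared container instead of four loose globals; each worker process
# gets its own copy, populated by the multiprocessing pool initializer.
_sampling_convenience_dump = SimpleNamespace(
    likelihood=None,
    priors=None,
    search_parameter_keys=None,
    use_ratio=False,
)

def _initialize_global_variables(likelihood, priors, search_parameter_keys, use_ratio):
    """Store sampling inputs where pool workers can reach them."""
    _sampling_convenience_dump.likelihood = likelihood
    _sampling_convenience_dump.priors = priors
    _sampling_convenience_dump.search_parameter_keys = search_parameter_keys
    _sampling_convenience_dump.use_ratio = use_ratio
```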