Skip to content

Commit

Permalink
DEV: update the nessai interface for v0.7.0
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will authored and ColmTalbot committed Nov 16, 2022
1 parent c65f9f1 commit 2d1c9e9
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 85 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ as requested in their associated documentation.
* `pymultinest <https://github.com/JohannesBuchner/PyMultiNest>`__
* `cpnest <https://github.com/johnveitch/cpnest>`__
* `emcee <https://github.com/dfm/emcee>`__
* `nessai <https://github.com/mj-will/nessai>`_
* `ptemcee <https://github.com/willvousden/ptemcee>`__
* `ptmcmcsampler <https://github.com/jellis18/PTMCMCSampler>`__
* `pypolychord <https://github.com/PolyChord/PolyChordLite>`__
Expand Down
227 changes: 162 additions & 65 deletions bilby/core/sampler/nessai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import sys

import numpy as np
from pandas import DataFrame
from scipy.special import logsumexp

from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger
from .base_sampler import NestedSampler
from .base_sampler import NestedSampler, signal_wrapper


class Nessai(NestedSampler):
Expand All @@ -19,41 +21,61 @@ class Nessai(NestedSampler):
"""

_default_kwargs = None
_run_kwargs_list = None
sampling_seed_key = "seed"

@property
def run_kwargs_list(self):
"""List of kwargs used in the run method of :code:`FlowSampler`"""
if not self._run_kwargs_list:
from nessai.utils.bilbyutils import get_run_kwargs_list

self._run_kwargs_list = get_run_kwargs_list()
ignored_kwargs = ["save"]
for ik in ignored_kwargs:
if ik in self._run_kwargs_list:
self._run_kwargs_list.remove(ik)
return self._run_kwargs_list

@property
def default_kwargs(self):
"""Default kwargs for nessai.
Retrieves default values from nessai directly and then includes any
bilby specific defaults. This avoids the need to update bilby when the
defaults change or new kwargs are added to nessai.
Includes the following kwargs that are specific to bilby:
- :code:`nessai_log_level`: allows setting the logging level in nessai
- :code:`nessai_logging_stream`: allows setting the logging stream
- :code:`nessai_plot`: allows toggling the plotting in FlowSampler.run
"""
if not self._default_kwargs:
from inspect import signature

from nessai.flowsampler import FlowSampler
from nessai.nestedsampler import NestedSampler
from nessai.proposal import AugmentedFlowProposal, FlowProposal

kwargs = {}
classes = [
AugmentedFlowProposal,
FlowProposal,
NestedSampler,
FlowSampler,
]
for c in classes:
kwargs.update(
{
k: v.default
for k, v in signature(c).parameters.items()
if v.default is not v.empty
}
)
from nessai.utils.bilbyutils import get_all_kwargs

kwargs = get_all_kwargs()

# Defaults for bilby that will override nessai defaults
bilby_defaults = dict(output=None, exit_code=self.exit_code)
bilby_defaults = dict(
output=None,
exit_code=self.exit_code,
nessai_log_level=None,
nessai_logging_stream="stdout",
nessai_plot=True,
plot_posterior=False, # bilby already produces a posterior plot
log_on_iteration=False, # Use periodic logging by default
logging_interval=60, # Log every 60 seconds
)
kwargs.update(bilby_defaults)
# Kwargs that cannot be set in bilby
remove = [
"save",
"signal_handling",
]
for k in remove:
if k in kwargs:
kwargs.pop(k)
self._default_kwargs = kwargs
return self._default_kwargs

Expand All @@ -72,12 +94,10 @@ def log_prior(self, theta):
"""
return self.priors.ln_prob(theta, axis=0)

def run_sampler(self):
from nessai.flowsampler import FlowSampler
from nessai.livepoint import dict_to_live_points, live_points_to_array
def get_nessai_model(self):
"""Get the model for nessai."""
from nessai.livepoint import dict_to_live_points
from nessai.model import Model as BaseModel
from nessai.posterior import compute_weights
from nessai.utils import setup_logger

class Model(BaseModel):
"""A wrapper class to pass our log_likelihood and priors into nessai
Expand Down Expand Up @@ -124,47 +144,115 @@ def new_point_log_prob(self, x):
"""Proposal probability for new the point"""
return self.log_prior(x)

# Setup the logger for nessai using the same settings as the bilby logger
setup_logger(
self.outdir, label=self.label, log_level=logger.getEffectiveLevel()
)
@staticmethod
def from_unit_hypercube(x):
"""Map samples from the unit hypercube to the prior."""
theta = {}
for n in self._search_parameter_keys:
theta[n] = self.priors[n].rescale(x[n])
return dict_to_live_points(theta)

@staticmethod
def to_unit_hypercube(x):
"""Map samples from the prior to the unit hypercube."""
theta = {n: x[n] for n in self._search_parameter_keys}
return dict_to_live_points(self.priors.cdf(theta))

model = Model(self.search_parameter_keys, self.priors)
try:
out = FlowSampler(model, **self.kwargs)
out.run(save=True, plot=self.plot)
except TypeError as e:
raise TypeError(f"Unable to initialise nessai sampler with error: {e}")
except (SystemExit, KeyboardInterrupt) as e:
import sys

logger.info(
f"Caught {type(e).__name__} with args {e.args}, "
f"exiting with signal {self.exit_code}"
)
sys.exit(self.exit_code)
return model

def split_kwargs(self):
"""Split kwargs into configuration and run time kwargs"""
kwargs = self.kwargs.copy()
run_kwargs = {}
for k in self.run_kwargs_list:
run_kwargs[k] = kwargs.pop(k)
run_kwargs["plot"] = kwargs.pop("nessai_plot")
return kwargs, run_kwargs

def get_posterior_weights(self):
"""Get the posterior weights for the nested samples"""
from nessai.posterior import compute_weights

_, log_weights = compute_weights(
np.array(self.fs.nested_samples["logL"]),
np.array(self.fs.ns.state.nlive),
)
w = np.exp(log_weights - logsumexp(log_weights))
return w

def get_nested_samples(self):
"""Get the nested samples dataframe"""
ns = DataFrame(self.fs.nested_samples)
ns.rename(
columns=dict(logL="log_likelihood", logP="log_prior", it="iteration"),
inplace=True,
)
return ns

def update_result(self):
"""Update the result object."""
from nessai.livepoint import live_points_to_array

# Manually set likelihood evaluations because parallelisation breaks the counter
self.result.num_likelihood_evaluations = out.ns.likelihood_evaluations[-1]
self.result.num_likelihood_evaluations = self.fs.ns.total_likelihood_evaluations

self.result.samples = live_points_to_array(
out.posterior_samples, self.search_parameter_keys
self.fs.posterior_samples, self.search_parameter_keys
)
self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
self.result.nested_samples = DataFrame(out.nested_samples)
self.result.nested_samples.rename(
columns=dict(logL="log_likelihood", logP="log_prior"), inplace=True
self.result.log_likelihood_evaluations = self.fs.posterior_samples["logL"]
self.result.nested_samples = self.get_nested_samples()
self.result.nested_samples["weights"] = self.get_posterior_weights()
self.result.log_evidence = self.fs.log_evidence
self.result.log_evidence_err = self.fs.log_evidence_error

@signal_wrapper
def run_sampler(self):
"""Run the sampler.
Nessai is designed to be ran in two stages, initialise the sampler
and then call the run method with additional configuration. This means
there are effectively two sets of keyword arguments: one for
initializing the sampler and the other for the run function.
"""
from nessai.flowsampler import FlowSampler
from nessai.utils import setup_logger

kwargs, run_kwargs = self.split_kwargs()

# Setup the logger for nessai, use nessai_log_level if specified, else use
# the level of the bilby logger.
nessai_log_level = kwargs.pop("nessai_log_level")
if nessai_log_level is None or nessai_log_level == "bilby":
nessai_log_level = logger.getEffectiveLevel()
nessai_logging_stream = kwargs.pop("nessai_logging_stream")

setup_logger(
self.outdir,
label=self.label,
log_level=nessai_log_level,
stream=nessai_logging_stream,
)
_, log_weights = compute_weights(
np.array(self.result.nested_samples.log_likelihood),
np.array(out.ns.state.nlive),

# Get the nessai model
model = self.get_nessai_model()

# Configure the sampler
self.fs = FlowSampler(
model,
signal_handling=False, # Disable signal handling so it can be handled by bilby
**kwargs,
)
self.result.nested_samples["weights"] = np.exp(log_weights)
self.result.log_evidence = out.ns.log_evidence
self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive)
# Run the sampler
self.fs.run(**run_kwargs)

# Update the result
self.update_result()

return self.result

def _translate_kwargs(self, kwargs):
"""Translate the keyword arguments"""
super()._translate_kwargs(kwargs)
if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs:
Expand All @@ -178,10 +266,7 @@ def _translate_kwargs(self, kwargs):
kwargs["n_pool"] = self._npool

def _verify_kwargs_against_default_kwargs(self):
"""
Set the directory where the output will be written
and check resume and checkpoint status.
"""
"""Verify the keyword arguments"""
if "config_file" in self.kwargs:
d = load_json(self.kwargs["config_file"], None)
self.kwargs.update(d)
Expand All @@ -190,10 +275,6 @@ def _verify_kwargs_against_default_kwargs(self):
if not self.kwargs["plot"]:
self.kwargs["plot"] = self.plot

if self.kwargs["n_pool"] == 1 and self.kwargs["max_threads"] == 1:
logger.warning("Setting pool to None (n_pool=1 & max_threads=1)")
self.kwargs["n_pool"] = None

if not self.kwargs["output"]:
self.kwargs["output"] = os.path.join(
self.outdir, f"{self.label}_nessai", ""
Expand All @@ -202,5 +283,21 @@ def _verify_kwargs_against_default_kwargs(self):
check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
NestedSampler._verify_kwargs_against_default_kwargs(self)

def write_current_state(self):
"""Write the current state of the sampler"""
self.fs.ns.checkpoint()

def write_current_state_and_exit(self, signum=None, frame=None):
"""
Overwrites the base class to make sure that :code:`Nessai` terminates
properly.
"""
if hasattr(self, "fs"):
self.fs.terminate_run(code=signum)
else:
logger.warning("Sampler is not initialized")
self._log_interruption(signum=signum)
sys.exit(self.exit_code)

def _setup_pool(self):
pass
2 changes: 1 addition & 1 deletion sampler_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ pymultinest
kombine
ultranest>=3.0.0
dnest4
nessai>=0.2.3
nessai>=0.7.0
schwimmbad
zeus-mcmc>=2.3.0
25 changes: 14 additions & 11 deletions test/core/sampler/nessai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def setUp(self):
plot=False,
skip_import_verification=True,
sampling_seed=150914,
npool=None, # TODO: remove when support for nessai<0.7.0 is dropped
)
self.expected = self.sampler.default_kwargs
self.expected["n_pool"] = 1 # Because npool=1 by default
self.expected['output'] = 'outdir/label_nessai/'
self.expected['seed'] = 150914

Expand All @@ -48,28 +48,31 @@ def test_translate_kwargs_nlive(self):

def test_translate_kwargs_npool(self):
expected = self.expected.copy()
expected["n_pool"] = None
expected["n_pool"] = 2
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs["n_pool"]
new_kwargs[equiv] = None
new_kwargs[equiv] = 2
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)

def test_translate_kwargs_seed(self):
assert self.expected["seed"] == 150914
def test_split_kwargs(self):
kwargs, run_kwargs = self.sampler.split_kwargs()
assert "save" not in run_kwargs
assert "plot" in run_kwargs

def test_npool_max_threads(self):
# TODO: remove when support for nessai<0.7.0 is dropped
def test_translate_kwargs_no_npool(self):
expected = self.expected.copy()
expected["n_pool"] = None
expected["max_threads"] = 1
expected["n_pool"] = 3
new_kwargs = self.sampler.kwargs.copy()
new_kwargs["n_pool"] = 1
new_kwargs["max_threads"] = 1
del new_kwargs["n_pool"]
self.sampler._npool = 3
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)

def test_translate_kwargs_seed(self):
assert self.expected["seed"] == 150914

@patch("builtins.open", mock_open(read_data='{"nlive": 4000}'))
def test_update_from_config_file(self):
expected = self.expected.copy()
Expand Down
10 changes: 2 additions & 8 deletions test/integration/sampler_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@
kombine=dict(iterations=200, nwalkers=10, autoburnin=False),
nessai=dict(
nlive=100,
poolsize=1000,
max_iteration=1000,
max_threads=3,
poolsize=100,
max_iteration=500,
),
nestle=dict(nlive=100),
ptemcee=dict(
Expand Down Expand Up @@ -159,11 +158,6 @@ def _run_with_signal_handling(self, sampler, pool_size=1):
pytest.skip(f"{sampler} cannot be parallelized")
if sys.version_info.minor == 8 and sampler.lower == "cpnest":
pytest.skip("Pool interrupting broken for cpnest with py3.8")
if sampler.lower() == "nessai" and pool_size > 1:
pytest.skip(
"Interrupting with a pool is failing in pytest. "
"Likely due to interactions with the signal handling in nessai."
)
pid = os.getpid()
print(sampler)

Expand Down

0 comments on commit 2d1c9e9

Please sign in to comment.