diff --git a/docs/conf.py b/docs/conf.py index d25279ee..0a27a34d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -67,3 +67,11 @@ autoapi_dirs = ["../nessai/"] autoapi_add_toctree_entry = False autoapi_options = ["members", "show-inheritance", "show-module-summary"] + +# -- RST config -------------------------------------------------------------- +# Inline python highlighting, based on https://stackoverflow.com/a/68931907 +rst_prolog = """ +.. role:: python(code) + :language: python + :class: highlight +""" diff --git a/docs/docutils.conf b/docs/docutils.conf new file mode 100644 index 00000000..1bf4d832 --- /dev/null +++ b/docs/docutils.conf @@ -0,0 +1,2 @@ +[restructuredtext parser] +syntax_highlight = short diff --git a/docs/further-details.rst b/docs/further-details.rst index 4f2fc6a6..ce56b395 100644 --- a/docs/further-details.rst +++ b/docs/further-details.rst @@ -61,6 +61,90 @@ Using analytic priors To use this setting, the user must re-define ``new_point`` when defining the model as described in :doc:`running the sampler`. This method must return samples as live points, see :ref:`using live points`. Once the method is redefined, set :code:`analytic_priors=True` when calling :py:class:`~nessai.flowsampler.FlowSampler`. +Checkpointing and resuming +========================== + +Both the standard and importance nested samplers support checkpointing and +resuming. By default, the samplers periodically checkpoint to a pickle file based +on the time elapsed since the last checkpoint. This behaviour can be configured +via various keyword arguments. + + +Configuration +------------- + +The following options are available in all the sampler classes: + +* :python:`checkpointing: bool`: Boolean to toggle checkpointing. If false, the sampler will not periodically checkpoint but will checkpoint at the end of sampling. +* :python:`checkpoint_on_iteration: bool`: Boolean to enable checkpointing based on the number of iterations rather than the elapsed time. 
+* :python:`checkpoint_interval: int`: The interval between checkpoints; the units depend on the value of :python:`checkpoint_on_iteration`: if it is false, the interval is specified in seconds; if it is true, the interval is specified in iterations. +* :python:`checkpoint_callback: Callable`: Callback function to be used instead of the default function. See `Checkpoint callbacks`_ for more details. + +The following options are available when creating an instance of +:py:class:`~nessai.flowsampler.FlowSampler`: + +* :python:`resume: bool`: Boolean to entirely enable or disable resuming irrespective of whether there is a file or data to resume from. +* :python:`resume_file: str`: Name of the resume file. +* :python:`resume_data: Any`: Data to resume the sampler from instead of a resume file. The data will be passed to the :python:`resume_from_pickled_sampler` method of the relevant class. + + +Resuming a sampling run +----------------------- + +A sampling run can be resumed from either an existing resume file, which is +loaded automatically, or by specifying pickled data to resume from. +We recommend using the resume files, which are produced automatically, for +most applications. + +The recommended method for resuming a run is by calling :py:class:`~nessai.flowsampler.FlowSampler` with +the same arguments that were originally used to start the run, ensuring that +:python:`resume=True` and :python:`resume_file` matches the name of the +:code:`.pkl` file in the output directory (the default is +:code:`nested_sampler_resume.pkl`). + +.. note:: + + Depending on how the sampling was interrupted, some progress may be lost and + the sampling may resume from an earlier iteration. + +Alternatively, you can specify the :python:`resume_data` argument, which takes +priority over the resume file. +This will be passed to the :python:`resume_from_pickled_sampler` method of the +corresponding sampler class. 
+ + +Checkpoint callbacks +-------------------- + +Checkpoint callbacks allow the user to specify a custom function to use for +checkpointing the sampler. +This allows the sampler, for example, to checkpoint to an existing file rather than to the default resume file. + +The checkpoint callback function will be called in the :code:`checkpoint` method +with the class instance as the only argument, i.e. +:python:`checkpoint_callback(self)`. + +All the sampler classes define custom :py:meth:`~nessai.samplers.base.BaseNestedSampler.__getstate__` methods that are +compatible with pickle and can be used to obtain a pickled representation of +the state of the sampler. Below is an example of a valid callback: + +.. code-block:: python + + import pickle + filename = "checkpoint.pkl" + + def checkpoint_callback(state): + with open(filename, "wb") as f: + pickle.dump(state, f) + +This could then be passed as a keyword argument when running or resuming a sampler +via :py:class:`~nessai.flowsampler.FlowSampler`. + +.. warning:: + The checkpoint callback is not included in the output of :python:`__getstate__` + and must be specified when resuming the sampler via :py:class:`~nessai.flowsampler.FlowSampler`. + + Detailed explanation of outputs =============================== diff --git a/nessai/flowsampler.py b/nessai/flowsampler.py index 9156a9bf..fc01d921 100644 --- a/nessai/flowsampler.py +++ b/nessai/flowsampler.py @@ -53,7 +53,7 @@ class FlowSampler: signal_handling : bool Enable or disable signal handling. exit_code : int, optional - Exit code to use when forceably exiting the sampler. + Exit code to use when forcibly exiting the sampler. close_pool : bool If True, the multiprocessing pool will be closed once the run method has been called. 
Disables the option in :code:`NestedSampler` if @@ -167,6 +167,7 @@ def __init__( model=model, weights_path=weights_path, flow_config=kwargs.get("flow_config"), + checkpoint_callback=kwargs.get("checkpoint_callback"), ) else: self.ns = self._resume_from_file( @@ -175,6 +176,7 @@ def __init__( resume_file=resume_file, weights_path=weights_path, flow_config=kwargs.get("flow_config"), + checkpoint_callback=kwargs.get("checkpoint_callback"), ) else: logger.debug("Not resuming sampler") @@ -196,7 +198,7 @@ def __init__( else: logger.warning( "Signal handling is disabled. nessai will not automatically " - "checkpoint when exitted." + "checkpoint when exited." ) def check_resume(self, resume_file, resume_data): @@ -218,6 +220,7 @@ def _resume_from_file( model, weights_path, flow_config, + **kwargs, ): logger.info(f"Trying to resume sampler from {resume_file}") try: @@ -226,6 +229,7 @@ def _resume_from_file( model, weights_path=weights_path, flow_config=flow_config, + **kwargs, ) except (FileNotFoundError, RuntimeError) as e: logger.error( @@ -239,6 +243,7 @@ def _resume_from_file( model, weights_path=weights_path, flow_config=flow_config, + **kwargs, ) except RuntimeError as e: logger.error( @@ -256,6 +261,7 @@ def _resume_from_data( model, weights_path, flow_config, + **kwargs, ): logger.info("Trying to resume sampler from `resume_data`") return SamplerClass.resume_from_pickled_sampler( @@ -263,6 +269,7 @@ def _resume_from_data( model, weights_path=weights_path, flow_config=flow_config, + **kwargs, ) @property diff --git a/nessai/samplers/base.py b/nessai/samplers/base.py index b6d7cb6c..cfbd0b32 100644 --- a/nessai/samplers/base.py +++ b/nessai/samplers/base.py @@ -6,7 +6,7 @@ import os import pickle import time -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from glasflow import __version__ as glasflow_version import numpy as np @@ -41,6 +41,10 @@ class BaseNestedSampler(ABC): checkpoint_on_iteration : bool If true the 
checkpointing interval is checked against the number of iterations + checkpoint_callback : Callable, optional + Callback function to be used instead of the default function. The + function will be called in the :code:`checkpoint` method as: + :code:`checkpoint_callback(self)`. logging_interval : int, optional The interval in seconds used for periodic logging. If not specified, then periodic logging is disabled. @@ -69,6 +73,7 @@ def __init__( checkpointing: bool = True, checkpoint_interval: int = 600, checkpoint_on_iteration: bool = False, + checkpoint_callback: Optional[Callable] = None, logging_interval: int = None, log_on_iteration: bool = True, resume_file: str = None, @@ -90,6 +95,7 @@ def __init__( self.checkpointing = checkpointing self.checkpoint_interval = checkpoint_interval self.checkpoint_on_iteration = checkpoint_on_iteration + self.checkpoint_callback = checkpoint_callback if self.checkpoint_on_iteration: self._last_checkpoint = 0 else: @@ -268,13 +274,21 @@ def checkpoint( return self.sampling_time += now - self.sampling_start_time logger.info("Checkpointing nested sampling") - safe_file_dump( - self, self.resume_file, pickle, save_existing=save_existing - ) + if self.checkpoint_callback: + self.checkpoint_callback(self) + else: + safe_file_dump( + self, self.resume_file, pickle, save_existing=save_existing + ) self.sampling_start_time = datetime.datetime.now() @classmethod - def resume_from_pickled_sampler(cls, sampler: Any, model: Model): + def resume_from_pickled_sampler( + cls, + sampler: Any, + model: Model, + checkpoint_callback: Optional[Callable] = None, + ): """Resume from pickle data. Parameters @@ -283,6 +297,9 @@ def resume_from_pickled_sampler(cls, sampler: Any, model: Model): Pickle data model : :obj:`nessai.model.Model` User-defined model + checkpoint_callback : Optional[Callable] + Checkpoint callback function. If not specified, the default method + will be used. 
Returns ------- @@ -297,6 +314,7 @@ def resume_from_pickled_sampler(cls, sampler: Any, model: Model): ) sampler.model = model sampler.resumed = True + sampler.checkpoint_callback = checkpoint_callback return sampler @classmethod @@ -349,7 +367,7 @@ def get_result_dictionary(self): def __getstate__(self): d = self.__dict__ - exclude = {"model", "proposal"} + exclude = {"model", "proposal", "checkpoint_callback"} state = {k: d[k] for k in d.keys() - exclude} state["_previous_likelihood_evaluations"] = d[ "model" diff --git a/nessai/samplers/importancesampler.py b/nessai/samplers/importancesampler.py index 4d83f8ee..d67d9356 100644 --- a/nessai/samplers/importancesampler.py +++ b/nessai/samplers/importancesampler.py @@ -5,7 +5,7 @@ import datetime import logging import os -from typing import Any, List, Literal, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import matplotlib import matplotlib.pyplot as plt @@ -101,6 +101,7 @@ def __init__( checkpointing: bool = True, checkpoint_interval: int = 600, checkpoint_on_iteration: bool = False, + checkpoint_callback: Optional[Callable] = None, save_existing_checkpoint: bool = False, logging_interval: int = None, log_on_iteration: bool = True, @@ -146,6 +147,7 @@ def __init__( checkpointing=checkpointing, checkpoint_interval=checkpoint_interval, checkpoint_on_iteration=checkpoint_on_iteration, + checkpoint_callback=checkpoint_callback, logging_interval=logging_interval, log_on_iteration=log_on_iteration, resume_file=resume_file, @@ -1900,7 +1902,7 @@ def get_result_dictionary(self): @classmethod def resume_from_pickled_sampler( - cls, sampler, model, flow_config=None, weights_path=None + cls, sampler, model, flow_config=None, weights_path=None, **kwargs ): """Resume from a pickled sampler. @@ -1915,6 +1917,8 @@ def resume_from_pickled_sampler( weights_path : Optional[dict] Path to the weights files that will override the value stored in the proposal. 
+ kwargs : + Keyword arguments passed to the parent class's method. Returns ------- @@ -1922,7 +1926,7 @@ def resume_from_pickled_sampler( """ cls.add_fields() obj = super(ImportanceNestedSampler, cls).resume_from_pickled_sampler( - sampler, model + sampler, model, **kwargs ) if flow_config is None: flow_config = {} @@ -1942,7 +1946,7 @@ def resume_from_pickled_sampler( def __getstate__(self): d = self.__dict__ - exclude = {"model", "proposal", "log_q"} + exclude = {"model", "proposal", "log_q", "checkpoint_callback"} state = {k: d[k] for k in d.keys() - exclude} if d.get("model") is not None: state["_previous_likelihood_evaluations"] = d[ diff --git a/nessai/samplers/nestedsampler.py b/nessai/samplers/nestedsampler.py index 977be399..9a3397be 100644 --- a/nessai/samplers/nestedsampler.py +++ b/nessai/samplers/nestedsampler.py @@ -162,6 +162,7 @@ def __init__( checkpoint_interval=600, checkpoint_on_iteration=False, checkpoint_on_training=False, + checkpoint_callback=None, logging_interval=None, log_on_iteration=True, resume_file=None, @@ -202,6 +203,7 @@ def __init__( checkpointing=checkpointing, checkpoint_interval=checkpoint_interval, checkpoint_on_iteration=checkpoint_on_iteration, + checkpoint_callback=checkpoint_callback, logging_interval=logging_interval, log_on_iteration=log_on_iteration, resume_file=resume_file, @@ -1343,7 +1345,7 @@ def get_result_dictionary(self): @classmethod def resume_from_pickled_sampler( - cls, sampler, model, flow_config=None, weights_path=None + cls, sampler, model, flow_config=None, weights_path=None, **kwargs ): """Resume from a pickled sampler. @@ -1358,13 +1360,15 @@ def resume_from_pickled_sampler( weights_path : Optional[str] Weights file to use in place of the weights file stored in the pickle file. + kwargs : + Keyword arguments passed to the parent class's method. 
Returns ------- Instance of NestedSampler """ obj = super(NestedSampler, cls).resume_from_pickled_sampler( - sampler, model + sampler, model, **kwargs ) obj._uninformed_proposal.resume(model) if flow_config is None: diff --git a/tests/test_flowsampler.py b/tests/test_flowsampler.py index c14dc4fa..cf408ecd 100644 --- a/tests/test_flowsampler.py +++ b/tests/test_flowsampler.py @@ -182,6 +182,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path): model=model, weights_path=None, flow_config=None, + checkpoint_callback=None, ) @@ -205,6 +206,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path): model=model, weights_path=None, flow_config=None, + checkpoint_callback=None, ) @@ -377,6 +379,7 @@ def test_init_resume(tmp_path, test_old, error): integration_model, flow_config=flow_config, weights_path=weights_file, + checkpoint_callback=None, ) assert fs.ns == "ns" diff --git a/tests/test_samplers/test_base_sampler.py b/tests/test_samplers/test_base_sampler.py index 0e5729ac..51b62c1a 100644 --- a/tests/test_samplers/test_base_sampler.py +++ b/tests/test_samplers/test_base_sampler.py @@ -273,6 +273,7 @@ def test_checkpoint_iteration(sampler, wait, periodic): sampler.checkpoint_iterations = [10] sampler.checkpoint_on_iteration = True sampler.checkpoint_interval = 10 + sampler.checkpoint_callback = None sampler._last_checkpoint = 0 sampler.iteration = 20 now = datetime.datetime.now() @@ -307,6 +308,7 @@ def test_checkpoint_time(sampler, wait): sampler.checkpoint_iterations = [10] sampler.checkpoint_on_iteration = False sampler.checkpoint_interval = 15 * 60 + sampler.checkpoint_callback = None sampler.sampling_start_time = now - datetime.timedelta(minutes=32) sampler._last_checkpoint = now - datetime.timedelta(minutes=16) sampler.iteration = 20 @@ -335,6 +337,7 @@ def test_checkpoint_periodic_skipped_iteration(sampler): sampler.iteration = 10 sampler._last_checkpoint = 9 sampler.checkpoint_interval = 10 + sampler.checkpoint_callback = None with 
patch("nessai.samplers.base.safe_file_dump") as sfd_mock: BaseNestedSampler.checkpoint(sampler, periodic=True) sfd_mock.assert_not_called() @@ -346,6 +349,7 @@ def test_checkpoint_periodic_skipped_time(sampler): sampler.iteration = 10 sampler._last_checkpoint = datetime.datetime.now() sampler.checkpoint_interval = 600 + sampler.checkpoint_callback = None with patch("nessai.samplers.base.safe_file_dump") as sfd_mock: BaseNestedSampler.checkpoint(sampler, periodic=True) sfd_mock.assert_not_called() @@ -357,6 +361,7 @@ def test_checkpoint_force(sampler): sampler.sampling_start_time = now - datetime.timedelta(minutes=32) sampler.sampling_time = datetime.timedelta() sampler.resume_file = "test.pkl" + sampler.checkpoint_callback = None with patch("nessai.samplers.base.safe_file_dump") as sfd_mock: BaseNestedSampler.checkpoint(sampler, periodic=True, force=True) sfd_mock.assert_called_once_with( @@ -364,6 +369,25 @@ def test_checkpoint_force(sampler): ) +def test_checkpoint_callback(sampler): + """Assert the checkpoint callback is used""" + + callback = MagicMock() + + sampler.checkpoint_iterations = [10] + sampler.checkpoint_on_iteration = True + sampler.checkpoint_interval = 10 + sampler.checkpoint_callback = callback + sampler._last_checkpoint = 0 + sampler.iteration = 20 + now = datetime.datetime.now() + sampler.sampling_start_time = now + sampler.sampling_time = datetime.timedelta() + + BaseNestedSampler.checkpoint(sampler) + callback.assert_called_once_with(sampler) + + def test_nested_sampling_loop(sampler): """Assert an error is raised""" with pytest.raises(NotImplementedError): diff --git a/tests/test_sampling/test_ins_sampling.py b/tests/test_sampling/test_ins_sampling.py index 3f1c0d1e..dba8e153 100644 --- a/tests/test_sampling/test_ins_sampling.py +++ b/tests/test_sampling/test_ins_sampling.py @@ -1,5 +1,6 @@ """Test sampling with the importance nested sampler""" import os +import pickle from nessai.flowsampler import FlowSampler import numpy as np @@ -42,3 
+43,62 @@ def test_ins_resume(tmp_path, model, flow_config): assert fp.ns.max_iteration == 2 assert fp.ns.finalised is True np.testing.assert_array_almost_equal(new_log_q, original_log_q) + + +@pytest.mark.slow_integration_test +def test_ins_checkpoint_callback(tmp_path, model, flow_config): + output = tmp_path / "test_ins_checkpoint_callback" + + filename = os.path.join(output, "test.pkl") + resume_file = "resume.pkl" + + def checkpoint_callback(state): + with open(filename, "wb") as f: + pickle.dump(state, f) + + fs = FlowSampler( + model, + output=output, + resume=True, + nlive=500, + min_samples=50, + plot=False, + flow_config=flow_config, + checkpoint_on_iteration=True, + checkpoint_interval=1, + importance_nested_sampler=True, + max_iteration=2, + resume_file=resume_file, + checkpoint_callback=checkpoint_callback, + ) + fs.run() + assert fs.ns.iteration == 2 + assert os.path.exists(filename) + assert not os.path.exists(os.path.join(output, resume_file)) + + del fs + + with open(filename, "rb") as f: + resume_data = pickle.load(f) + + resume_data.test_variable = "abc" + + fs = FlowSampler( + model, + output=output, + resume=True, + nlive=500, + min_samples=50, + plot=False, + flow_config=flow_config, + checkpoint_on_iteration=True, + checkpoint_interval=1, + importance_nested_sampler=True, + max_iteration=2, + checkpoint_callback=checkpoint_callback, + resume_data=resume_data, + resume_file=resume_file, + ) + assert fs.ns.iteration == 2 + assert fs.ns.finalised is True + assert fs.ns.test_variable == "abc" diff --git a/tests/test_sampling/test_standard_sampling.py b/tests/test_sampling/test_standard_sampling.py index fef3e16f..0da1c3db 100644 --- a/tests/test_sampling/test_standard_sampling.py +++ b/tests/test_sampling/test_standard_sampling.py @@ -622,3 +622,53 @@ def test_sampling_result_extension(integration_model, tmp_path, extension): ) fs.run(plot=False) assert os.path.exists(os.path.join(output, f"result.{extension}")) + + 
+@pytest.mark.slow_integration_test +def test_sampling_with_checkpoint_callback(integration_model, tmp_path): + """Test the usage of the checkpoint callbacks""" + import pickle + + output = tmp_path / "test_callbacks" + output.mkdir() + + checkpoint_file = output / "test.pkl" + + def callback(state): + with open(checkpoint_file, "wb") as f: + pickle.dump(state, f) + + fs = FlowSampler( + integration_model, + output=output, + nlive=100, + plot=False, + proposal_plots=False, + checkpoint_callback=callback, + max_iteration=100, + checkpoint_on_iteration=True, + checkpoint_interval=50, + ) + fs.run(plot=False) + assert fs.ns.iteration == 100 + + del fs + + with open(checkpoint_file, "rb") as f: + resume_data = pickle.load(f) + + fs = FlowSampler( + integration_model, + output=output, + nlive=100, + plot=False, + proposal_plots=False, + checkpoint_callback=callback, + checkpoint_on_iteration=True, + checkpoint_interval=50, + resume_data=resume_data, + resume=True, + ) + fs.ns.max_iteration = 200 + fs.run() + assert fs.ns.iteration == 200