Merge pull request #355 from mj-will/add-checkpointing-callbacks
Add checkpointing callbacks
mj-will authored Nov 21, 2023
2 parents 5d93ca2 + c96ecb9 commit d5cd2b7
Showing 11 changed files with 278 additions and 14 deletions.
8 changes: 8 additions & 0 deletions docs/conf.py
@@ -67,3 +67,11 @@
autoapi_dirs = ["../nessai/"]
autoapi_add_toctree_entry = False
autoapi_options = ["members", "show-inheritance", "show-module-summary"]

# -- RST config --------------------------------------------------------------
# Inline python highlighting, based on https://stackoverflow.com/a/68931907
rst_prolog = """
.. role:: python(code)
:language: python
:class: highlight
"""
2 changes: 2 additions & 0 deletions docs/docutils.conf
@@ -0,0 +1,2 @@
[restructuredtext parser]
syntax_highlight = short
84 changes: 84 additions & 0 deletions docs/further-details.rst
@@ -61,6 +61,90 @@ Using analytic priors
To use this setting, the user must re-define ``new_point`` when defining the model as described in :doc:`running the sampler<running-the-sampler>`. This method must return samples as live points, see :ref:`using live points<Using live points>`. Once the method is redefined, set :code:`analytic_priors=True` when calling :py:class:`~nessai.flowsampler.FlowSampler`.


Checkpointing and resuming
==========================

Both the standard and importance nested samplers support checkpointing and
resuming. By default, the samplers periodically checkpoint to a pickle file
based on the time elapsed since the last checkpoint. This behaviour can be
configured via various keyword arguments.


Configuration
-------------

The following options are available in all the sampler classes:

* :python:`checkpointing: bool`: Boolean to toggle checkpointing. If false, the sampler will not periodically checkpoint but will checkpoint at the end of sampling.
* :python:`checkpoint_on_iteration: bool`: Boolean to enable checkpointing based on the number of iterations rather than the elapsed time.
* :python:`checkpoint_interval: int`: The interval between checkpoints; the units depend on the value of :python:`checkpoint_on_iteration`: if it is false, the interval is specified in seconds; if it is true, the interval is specified in iterations.
* :python:`checkpoint_callback: Callable`: Callback function to be used instead of the default function. See `Checkpoint callbacks`_ for more details.
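The interaction between :python:`checkpoint_on_iteration` and :python:`checkpoint_interval` can be sketched with a small, self-contained helper (purely illustrative; :python:`should_checkpoint` is an invented name, not part of nessai):

```python
import datetime

def should_checkpoint(
    checkpoint_on_iteration,
    checkpoint_interval,
    iteration=0,
    last_checkpoint_iteration=0,
    now=None,
    last_checkpoint_time=None,
):
    """Illustrative sketch of the two interval modes; not nessai's code."""
    if checkpoint_on_iteration:
        # Interval is measured in iterations
        return iteration - last_checkpoint_iteration >= checkpoint_interval
    # Interval is measured in seconds of elapsed wall-clock time
    return (now - last_checkpoint_time).total_seconds() >= checkpoint_interval
```

With the default settings in this commit (:python:`checkpoint_interval=600`, :python:`checkpoint_on_iteration=False`), this corresponds to checkpointing roughly every ten minutes of wall-clock time.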

The following options are available when creating an instance of
:py:class:`~nessai.flowsampler.FlowSampler`:

* :python:`resume: bool`: Boolean to enable or disable resuming entirely, irrespective of whether there is a file or data to resume from.
* :python:`resume_file: str`: Name of the resume file.
* :python:`resume_data: Any`: Data to resume the sampler from instead of a resume file. The data will be passed to the :python:`resume_from_pickled_sampler` of the relevant class.


Resuming a sampling run
-----------------------

A sampling run can be resumed from either an existing resume file, which is
loaded automatically, or by specifying pickled data to resume from.
We recommend using the resume files, which are produced automatically, for
most applications.

The recommended method for resuming a run is to call :py:class:`~nessai.flowsampler.FlowSampler` with
the same arguments that were originally used to start the run, ensuring
:python:`resume=True` and that :python:`resume_file` matches the name of the
:code:`.pkl` file in the output directory (the default is
:code:`nested_sampler_resume.pkl`).

.. note::

Depending on how the sampling was interrupted, some progress may be lost and
the sampling may resume from an earlier iteration.

Alternatively, you can specify the :python:`resume_data` argument which takes
priority over the resume file.
This will be passed to the :python:`resume_from_pickled_sampler` of the
corresponding sampler class.
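The precedence between these resume options can be summarised with a short sketch (an illustrative helper whose name is invented for this example; it mirrors the documented behaviour rather than nessai's exact internals):

```python
def select_resume_source(resume, resume_data=None, resume_file_exists=False):
    """Return which source a run would be resumed from.

    Illustrative only: reflects the documented precedence, where
    ``resume_data`` takes priority over an existing resume file.
    """
    if not resume:
        return "fresh start"
    if resume_data is not None:
        return "resume from data"
    if resume_file_exists:
        return "resume from file"
    return "fresh start"
```

For example, setting :python:`resume=False` always starts a fresh run, even if a resume file exists in the output directory.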


Checkpoint callbacks
--------------------

Checkpoint callbacks allow the user to specify a custom function to use for
checkpointing the sampler.
This allows, for example, the sampler to checkpoint to an existing file rather
than the default resume file.

The checkpoint callback function will be called in the :code:`checkpoint` method
with the class instance as the only argument, i.e.
:python:`checkpoint_callback(self)`.

All the sampler classes define custom :py:meth:`~nessai.samplers.base.BaseNestedSampler.__getstate__` methods that are
compatible with pickle and can be used to obtain a pickled representation of
the state of the sampler. Below is an example of a valid callback:

.. code-block:: python

    import pickle

    filename = "checkpoint.pkl"

    def checkpoint_callback(state):
        with open(filename, "wb") as f:
            pickle.dump(state, f)

This could then be passed as a keyword argument when running or resuming a
sampler via :py:class:`~nessai.flowsampler.FlowSampler`.
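The :python:`__getstate__` pattern the samplers rely on can be made concrete with a minimal stand-in class (purely illustrative; :python:`MinimalSampler` is not part of nessai):

```python
import pickle

class MinimalSampler:
    """Toy stand-in for a sampler class; not part of nessai."""

    def __init__(self, model, iteration=0, checkpoint_callback=None):
        self.model = model
        self.iteration = iteration
        self.checkpoint_callback = checkpoint_callback

    def __getstate__(self):
        # Exclude attributes that should not be pickled, mirroring the
        # exclusion pattern used by the sampler classes.
        d = self.__dict__
        exclude = {"model", "checkpoint_callback"}
        return {k: d[k] for k in d.keys() - exclude}

# Round-trip through pickle: the excluded attributes are dropped.
restored = pickle.loads(pickle.dumps(MinimalSampler(model=object(), iteration=5)))
```

After the round trip, :python:`restored.iteration` is preserved while the model and callback attributes are absent, which is why both must be supplied again when resuming.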

.. warning::

    The checkpoint callback is not included in the output of :python:`__getstate__`
    and must be specified when resuming the sampler via
    :py:class:`~nessai.flowsampler.FlowSampler`.


Detailed explanation of outputs
===============================

11 changes: 9 additions & 2 deletions nessai/flowsampler.py
@@ -53,7 +53,7 @@ class FlowSampler:
signal_handling : bool
Enable or disable signal handling.
exit_code : int, optional
Exit code to use when forceably exiting the sampler.
Exit code to use when forcibly exiting the sampler.
close_pool : bool
If True, the multiprocessing pool will be closed once the run method
has been called. Disables the option in :code:`NestedSampler` if
@@ -167,6 +167,7 @@ def __init__(
model=model,
weights_path=weights_path,
flow_config=kwargs.get("flow_config"),
checkpoint_callback=kwargs.get("checkpoint_callback"),
)
else:
self.ns = self._resume_from_file(
@@ -175,6 +176,7 @@
resume_file=resume_file,
weights_path=weights_path,
flow_config=kwargs.get("flow_config"),
checkpoint_callback=kwargs.get("checkpoint_callback"),
)
else:
logger.debug("Not resuming sampler")
@@ -196,7 +198,7 @@ def __init__(
else:
logger.warning(
"Signal handling is disabled. nessai will not automatically "
"checkpoint when exitted."
"checkpoint when exited."
)

def check_resume(self, resume_file, resume_data):
@@ -218,6 +220,7 @@ def _resume_from_file(
model,
weights_path,
flow_config,
**kwargs,
):
logger.info(f"Trying to resume sampler from {resume_file}")
try:
@@ -226,6 +229,7 @@
model,
weights_path=weights_path,
flow_config=flow_config,
**kwargs,
)
except (FileNotFoundError, RuntimeError) as e:
logger.error(
@@ -239,6 +243,7 @@
model,
weights_path=weights_path,
flow_config=flow_config,
**kwargs,
)
except RuntimeError as e:
logger.error(
@@ -256,13 +261,15 @@ def _resume_from_data(
model,
weights_path,
flow_config,
**kwargs,
):
logger.info("Trying to resume sampler from `resume_data`")
return SamplerClass.resume_from_pickled_sampler(
resume_data,
model,
weights_path=weights_path,
flow_config=flow_config,
**kwargs,
)

@property
30 changes: 24 additions & 6 deletions nessai/samplers/base.py
@@ -6,7 +6,7 @@
import os
import pickle
import time
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

from glasflow import __version__ as glasflow_version
import numpy as np
@@ -41,6 +41,10 @@ class BaseNestedSampler(ABC):
checkpoint_on_iteration : bool
If true the checkpointing interval is checked against the number of
iterations
checkpoint_callback : Callable
Callback function to be used instead of the default function. The
function will be called in the :code:`checkpoint` method as:
:code:`checkpoint_callback(self)`.
logging_interval : int, optional
The interval in seconds used for periodic logging. If not specified,
then periodic logging is disabled.
@@ -69,6 +73,7 @@ def __init__(
checkpointing: bool = True,
checkpoint_interval: int = 600,
checkpoint_on_iteration: bool = False,
checkpoint_callback: Optional[Callable] = None,
logging_interval: int = None,
log_on_iteration: bool = True,
resume_file: str = None,
@@ -90,6 +95,7 @@
self.checkpointing = checkpointing
self.checkpoint_interval = checkpoint_interval
self.checkpoint_on_iteration = checkpoint_on_iteration
self.checkpoint_callback = checkpoint_callback
if self.checkpoint_on_iteration:
self._last_checkpoint = 0
else:
@@ -268,13 +274,21 @@ def checkpoint(
return
self.sampling_time += now - self.sampling_start_time
logger.info("Checkpointing nested sampling")
safe_file_dump(
self, self.resume_file, pickle, save_existing=save_existing
)
if self.checkpoint_callback:
self.checkpoint_callback(self)
else:
safe_file_dump(
self, self.resume_file, pickle, save_existing=save_existing
)
self.sampling_start_time = datetime.datetime.now()

@classmethod
def resume_from_pickled_sampler(cls, sampler: Any, model: Model):
def resume_from_pickled_sampler(
cls,
sampler: Any,
model: Model,
checkpoint_callback: Optional[Callable] = None,
):
"""Resume from pickle data.
Parameters
@@ -283,6 +297,9 @@ def resume_from_pickled_sampler(cls, sampler: Any, model: Model):
Pickle data
model : :obj:`nessai.model.Model`
User-defined model
checkpoint_callback : Optional[Callable]
Checkpoint callback function. If not specified, the default method
will be used.
Returns
-------
@@ -297,6 +314,7 @@
)
sampler.model = model
sampler.resumed = True
sampler.checkpoint_callback = checkpoint_callback
return sampler

@classmethod
@@ -349,7 +367,7 @@ def get_result_dictionary(self):

def __getstate__(self):
d = self.__dict__
exclude = {"model", "proposal"}
exclude = {"model", "proposal", "checkpoint_callback"}
state = {k: d[k] for k in d.keys() - exclude}
state["_previous_likelihood_evaluations"] = d[
"model"
12 changes: 8 additions & 4 deletions nessai/samplers/importancesampler.py
@@ -5,7 +5,7 @@
import datetime
import logging
import os
from typing import Any, List, Literal, Optional, Union
from typing import Any, Callable, List, Literal, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
@@ -101,6 +101,7 @@ def __init__(
checkpointing: bool = True,
checkpoint_interval: int = 600,
checkpoint_on_iteration: bool = False,
checkpoint_callback: Optional[Callable] = None,
save_existing_checkpoint: bool = False,
logging_interval: int = None,
log_on_iteration: bool = True,
@@ -146,6 +147,7 @@
checkpointing=checkpointing,
checkpoint_interval=checkpoint_interval,
checkpoint_on_iteration=checkpoint_on_iteration,
checkpoint_callback=checkpoint_callback,
logging_interval=logging_interval,
log_on_iteration=log_on_iteration,
resume_file=resume_file,
@@ -1900,7 +1902,7 @@ def get_result_dictionary(self):

@classmethod
def resume_from_pickled_sampler(
cls, sampler, model, flow_config=None, weights_path=None
cls, sampler, model, flow_config=None, weights_path=None, **kwargs
):
"""Resume from a pickled sampler.
@@ -1915,14 +1917,16 @@ def resume_from_pickled_sampler(
weights_path : Optional[dict]
Path to the weights files that will override the value stored in
the proposal.
kwargs :
Keyword arguments passed to the parent class's method.
Returns
-------
Instance of ImportanceNestedSampler
"""
cls.add_fields()
obj = super(ImportanceNestedSampler, cls).resume_from_pickled_sampler(
sampler, model
sampler, model, **kwargs
)
if flow_config is None:
flow_config = {}
@@ -1942,7 +1946,7 @@

def __getstate__(self):
d = self.__dict__
exclude = {"model", "proposal", "log_q"}
exclude = {"model", "proposal", "log_q", "checkpoint_callback"}
state = {k: d[k] for k in d.keys() - exclude}
if d.get("model") is not None:
state["_previous_likelihood_evaluations"] = d[
8 changes: 6 additions & 2 deletions nessai/samplers/nestedsampler.py
@@ -162,6 +162,7 @@ def __init__(
checkpoint_interval=600,
checkpoint_on_iteration=False,
checkpoint_on_training=False,
checkpoint_callback=None,
logging_interval=None,
log_on_iteration=True,
resume_file=None,
@@ -202,6 +203,7 @@
checkpointing=checkpointing,
checkpoint_interval=checkpoint_interval,
checkpoint_on_iteration=checkpoint_on_iteration,
checkpoint_callback=checkpoint_callback,
logging_interval=logging_interval,
log_on_iteration=log_on_iteration,
resume_file=resume_file,
@@ -1343,7 +1345,7 @@ def get_result_dictionary(self):

@classmethod
def resume_from_pickled_sampler(
cls, sampler, model, flow_config=None, weights_path=None
cls, sampler, model, flow_config=None, weights_path=None, **kwargs
):
"""Resume from a pickled sampler.
@@ -1358,13 +1360,15 @@
weights_path : Optional[str]
Weights file to use in place of the weights file stored in the
pickle file.
kwargs :
Keyword arguments passed to the parent class's method.
Returns
-------
Instance of NestedSampler
"""
obj = super(NestedSampler, cls).resume_from_pickled_sampler(
sampler, model
sampler, model, **kwargs
)
obj._uninformed_proposal.resume(model)
if flow_config is None:
3 changes: 3 additions & 0 deletions tests/test_flowsampler.py
@@ -182,6 +182,7 @@ def test_resume_from_resume_data(flow_sampler, model, tmp_path):
model=model,
weights_path=None,
flow_config=None,
checkpoint_callback=None,
)


@@ -205,6 +206,7 @@ def test_resume_from_resume_file(flow_sampler, model, tmp_path):
model=model,
weights_path=None,
flow_config=None,
checkpoint_callback=None,
)


@@ -377,6 +379,7 @@ def test_init_resume(tmp_path, test_old, error):
integration_model,
flow_config=flow_config,
weights_path=weights_file,
checkpoint_callback=None,
)

assert fs.ns == "ns"
