From 7e6926f4719a331d06a66bc26be41d77779c3029 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Mon, 14 Feb 2022 18:40:56 -0500 Subject: [PATCH] More streamlined context caches usage (#547) * Removing context_cache kwargs. Users have to use instance attributes for context caches. * Adding context cache attributes and examples in docstring. --- docs/releasehistory.rst | 7 +- openmmtools/multistate/multistatesampler.py | 145 +++++++++----------- openmmtools/multistate/paralleltempering.py | 2 +- openmmtools/tests/test_sampling.py | 43 ++++++ 4 files changed, 117 insertions(+), 80 deletions(-) diff --git a/docs/releasehistory.rst b/docs/releasehistory.rst index 791bd960..353c517a 100644 --- a/docs/releasehistory.rst +++ b/docs/releasehistory.rst @@ -1,9 +1,14 @@ Release History *************** -0.2Y.Z - XXXXXX release +0.21.1 - Bugfix release ======================= +Bugfixes +-------- +- More streamlined context cache usage using instance attributes (`#547 `_). +- Improved docstring and examples for ``MultiStateSampler`` object. + 0.21.0 - Bugfix release ======================= diff --git a/openmmtools/multistate/multistatesampler.py b/openmmtools/multistate/multistatesampler.py index 15f83bae..37d8588c 100644 --- a/openmmtools/multistate/multistatesampler.py +++ b/openmmtools/multistate/multistatesampler.py @@ -126,15 +126,56 @@ class MultiStateSampler(object): sampler_states metadata is_completed - - :param number_of_iterations: Maximum number of integer iterations that will be run - - :param online_analysis_interval: How frequently to carry out online analysis in number of iterations - - :param online_analysis_target_error: Target free energy difference error float at which simulation will be stopped during online analysis, in dimensionless energy - - :param online_analysis_minimum_iterations: Minimum number of iterations needed before online analysis is run as int - + energy_context_cache : openmmtools.cache.ContextCache, default=openmmtools.cache.global_context_cache + Context cache to be used for energy computations. Defaults to using global context cache. + sampler_context_cache : openmmtools.cache.ContextCache, default=openmmtools.cache.global_context_cache + Context cache to be used for propagation. Defaults to using global context cache. + + Examples + -------- + Sampling multiple states of an alanine dipeptide in implicit solvent system. + + >>> import math + >>> import tempfile + >>> from openmm import unit + >>> from openmmtools import testsystems, states, mcmc + >>> from openmmtools.multistate import MultiStateSampler, MultiStateReporter + >>> testsystem = testsystems.AlanineDipeptideImplicit() + + Create thermodynamic states + + >>> n_replicas = 3 + >>> T_min = 298.0 * unit.kelvin # Minimum temperature. + >>> T_max = 600.0 * unit.kelvin # Maximum temperature. + >>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0) + ... for i in range(n_replicas)] + >>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0) + ... for i in range(n_replicas)] + >>> thermodynamic_states = [states.ThermodynamicState(system=testsystem.system, temperature=T) + ... for T in temperatures] + + Initialize simulation object with options. Run with a GHMC integrator. + + >>> move = mcmc.GHMCMove(timestep=2.0*unit.femtoseconds, n_steps=50) + >>> simulation = MultiStateSampler(mcmc_moves=move, number_of_iterations=2) + + Create simulation and store output in temporary file + + >>> storage_path = tempfile.NamedTemporaryFile(delete=False).name + '.nc' + >>> reporter = MultiStateReporter(storage_path, checkpoint_interval=1) + >>> simulation.create(thermodynamic_states=thermodynamic_states, + ... sampler_states=states.SamplerState(testsystem.positions), storage=reporter) + + Optionally, specify unlimited context cache attributes using the fastest mixed precision platform + + >>> from openmmtools.cache import ContextCache + >>> from openmmtools.utils import get_fastest_platform + >>> platform = get_fastest_platform(minimum_precision='mixed') + >>> simulation.energy_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform) + >>> simulation.sampler_context_cache = ContextCache(capacity=None, time_to_live=None, platform=platform) + + Run the simulation + >>> simulation.run() """ # ------------------------------------------------------------------------- @@ -194,8 +235,8 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1, self._have_displayed_citations_before = False - # Initializing context cache attributes to global cache - self.energy_context_cache, self.sampler_context_cache = cache.global_context_cache, cache.global_context_cache + # Initializing context cache attributes + self._initialize_context_caches() # Check convergence. if self.number_of_iterations == np.inf: @@ -207,7 +248,7 @@ def __init__(self, mcmc_moves=None, number_of_iterations=1, "specified maximum number of iterations!") @classmethod - def from_storage(cls, storage, energy_context_cache=None, propagation_context_cache=None): + def from_storage(cls, storage): """Constructor from an existing storage file. Parameters @@ -217,12 +258,6 @@ def from_storage(cls, storage, energy_context_cache=None, propagation_context_ca If :class:`Reporter`: uses the :class:`Reporter` options In the future this will be able to take a Storage class as well. - energy_context_cache : openmmtools.cache.ContextCache or None, optional, default None - Context cache to be used for energy computations. If None, a new fresh cache will be used. - - propagation_context_cache : openmmtools.cache.ContextCache or None, optional, default None - Context cache to be used for move/integrator propagation. If None, a new fresh cache will be used. - Returns ------- sampler : MultiStateSampler @@ -237,9 +272,7 @@ def from_storage(cls, storage, energy_context_cache=None, propagation_context_ca # Open the reporter to read the data. reporter.open(mode='r') sampler = cls._instantiate_sampler_from_reporter(reporter) - sampler._restore_sampler_from_reporter(reporter, - energy_context_cache=energy_context_cache, - propagation_context_cache=propagation_context_cache) + sampler._restore_sampler_from_reporter(reporter) finally: # Close reporter in reading mode. reporter.close() @@ -247,7 +280,7 @@ def from_storage(cls, storage, energy_context_cache=None, propagation_context_ca # We open the reporter only in node 0 in append mode ready for use sampler._reporter = reporter mpiplus.run_single_node(0, sampler._reporter.open, mode='a', - broadcast_result=False, sync_nodes=False) + broadcast_result=False, sync_nodes=False) # Don't write the new last iteration, we have not technically # written anything yet, so there is no "junk". return sampler @@ -491,7 +524,7 @@ def is_completed(self): def create(self, thermodynamic_states: list, sampler_states, storage, initial_thermodynamic_states=None, unsampled_thermodynamic_states=None, - metadata=None, energy_context_cache=None, sampler_context_cache=None): + metadata=None): """Create new multistate sampler simulation. Parameters @@ -532,10 +565,6 @@ def create(self, thermodynamic_states: list, sampler_states, storage, is None). metadata : dict, optional, default=None Simulation metadata to be stored in the file. - energy_context_cache : openmmtools.cache.ContextCache or None, optional, default None - Context cache to be used for energy computations. If None, global context cache will be used. - sampler_context_cache : openmmtools.cache.ContextCache or None, optional, default None - Context cache to be used for move/integrator propagation. If None, global context cache will be used. """ # Handle case in which storage is a string and not a Reporter object. self._reporter = self._reporter_from_storage(storage, check_exist=False) @@ -556,9 +585,7 @@ def create(self, thermodynamic_states: list, sampler_states, storage, self._pre_write_create(thermodynamic_states, sampler_states, storage, initial_thermodynamic_states=initial_thermodynamic_states, unsampled_thermodynamic_states=unsampled_thermodynamic_states, - metadata=metadata, - energy_context_cache=energy_context_cache, - sampler_context_cache=sampler_context_cache) + metadata=metadata) # Display papers to be cited. self._display_citations() @@ -769,9 +796,7 @@ def _pre_write_create(self, storage, initial_thermodynamic_states=None, unsampled_thermodynamic_states=None, - metadata=None, - energy_context_cache=None, - sampler_context_cache=None): + metadata=None,): """ Internal function which allocates and sets up ALL variables prior to actually using them. This is helpful to ensure subclasses have all variables created prior to writing them out with @@ -808,13 +833,6 @@ def _pre_write_create(self, metadata['title'] = default_title self._metadata = metadata - # Handling context cache parameters and attributes - # update context caches attributes handling inputs - self.energy_context_cache, self.sampler_context_cache = self._initialize_context_caches( - energy_context_cache, - sampler_context_cache - ) - # Save thermodynamic states. This sets n_replicas. self._thermodynamic_states = copy.deepcopy(thermodynamic_states) @@ -892,7 +910,7 @@ def _instantiate_sampler_from_reporter(cls, reporter): sampler._display_citations() return sampler - def _restore_sampler_from_reporter(self, reporter, energy_context_cache=None, propagation_context_cache=None): + def _restore_sampler_from_reporter(self, reporter): """ (Re-)initialize the instanced sampler from the reporter. Intended to be called as the second half of a :func:`from_storage` method after the :class:`MultiStateSampler` has been instanced from disk. @@ -978,12 +996,8 @@ def _read_options(check_iteration): self._last_mbar_f_k = last_mbar_f_k self._last_err_free_energy = last_err_free_energy - # Handle with context caches as specified - # update context caches attributes handling inputs - self.energy_context_cache, self.sampler_context_cache = self._initialize_context_caches( - energy_context_cache, - propagation_context_cache - ) + # Initialize context caches + self._initialize_context_caches() def _check_nan_energy(self): """Checks that energies are finite and abort otherwise. @@ -1652,41 +1666,16 @@ def __init__(self, error_message): raise RestorationError(message) - @staticmethod - def _initialize_context_caches(energy_context_cache=None, propagation_context_cache=None): + def _initialize_context_caches(self): """Handle energy and propagation context cache default behavior. - .. note:: As of 03-Feb-22 default behavior is to use the global cache. - - Parameters - ---------- - energy_context_cache : openmmtools.cache.ContextCache or None - Context cache to be used in energy computations. If None, - it will use the global context cache. - propagation_context_cache : openmmtools.cache.ContextCache or None - Context cache to be used in the propagation of the mcmc moves. If None, - it will use the global context cache. + Centralized API point where to initialize the context cache instance attributes. - Returns - ------- - energy_context_cache : openmmtools.cache.ContextCache - Context cache to be used in energy computations. - propagation_context_cache : openmmtools.cache.ContextCache - Context cache to be used in the propagation of the mcmc moves. + .. note:: As of 03-Feb-22 default behavior is to use the global cache. """ - # Handling energy context cache - if energy_context_cache is None: - # Default behavior, global context cache - energy_context_cache = cache.global_context_cache - elif not isinstance(energy_context_cache, cache.ContextCache): - raise ValueError("Energy context cache input is not a valid ContextCache or None type.") - # Handling propagation context cache - if propagation_context_cache is None: - # Default behavior, global context cache - propagation_context_cache = cache.global_context_cache - elif not isinstance(propagation_context_cache, cache.ContextCache): - raise ValueError("MCMC move context cache input is not a valid ContextCache or None type.") - return energy_context_cache, propagation_context_cache + # Default is using global context cache + self.energy_context_cache = cache.global_context_cache + self.sampler_context_cache = cache.global_context_cache # ------------------------------------------------------------------------- # Internal-usage: Test globals diff --git a/openmmtools/multistate/paralleltempering.py b/openmmtools/multistate/paralleltempering.py index 344219dd..890e9dd8 100644 --- a/openmmtools/multistate/paralleltempering.py +++ b/openmmtools/multistate/paralleltempering.py @@ -197,7 +197,7 @@ def _compute_replica_energies(self, replica_id): reference_thermodynamic_state = self._thermodynamic_states[0] # Get the context, any Integrator works. - context, integrator = cache.global_context_cache.get_context(reference_thermodynamic_state) + context, integrator = self.energy_context_cache.get_context(reference_thermodynamic_state) # Update positions and box vectors. sampler_state.apply_to_context(context) diff --git a/openmmtools/tests/test_sampling.py b/openmmtools/tests/test_sampling.py index 1962d7f9..821b739b 100644 --- a/openmmtools/tests/test_sampling.py +++ b/openmmtools/tests/test_sampling.py @@ -35,6 +35,7 @@ import mpiplus import openmmtools as mmtools +from openmmtools import cache from openmmtools import testsystems from openmmtools.multistate import MultiStateReporter from openmmtools.multistate import MultiStateSampler, MultiStateSamplerAnalyzer @@ -1442,6 +1443,48 @@ def test_online_analysis_stops(self): assert sampler._iteration < n_iterations assert sampler.is_completed + def test_context_cache_default(self): + """Test default behavior of context cache attributes.""" + sampler = self.SAMPLER() + global_context_cache = cache.global_context_cache + # Default is to use global context cache for both context cache attributes + assert sampler.sampler_context_cache is global_context_cache + assert sampler.energy_context_cache is global_context_cache + + def test_context_cache_energy_propagation(self): + """Test specifying different context caches for energy and propagation in a short simulation.""" + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + n_replicas = len(sampler_states) + if n_replicas == 1: + # This test is intended for use with more than one replica + return + + with self.temporary_storage_path() as storage_path: + # Create a replica exchange that propagates only 1 femtosecond + # per iteration so that positions won't change much. + move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) + sampler = self.SAMPLER(mcmc_moves=move) + reporter = self.REPORTER(storage_path) + self.call_sampler_create(sampler, reporter, + thermodynamic_states, sampler_states, + unsampled_states) + # Set context cache attributes + sampler.energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None) + sampler.sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None) + # Compute energies + sampler._compute_energies() + # Check only energy context cache has been accessed + assert sampler.energy_context_cache._lru._n_access > 0, \ + f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }." + assert sampler.sampler_context_cache._lru._n_access == 0, \ + f"{sampler.sampler_context_cache._lru._n_access} accesses, expected 0." + + # Propagate replicas + sampler._propagate_replicas() + # Check propagation context cache has been accessed after propagation + assert sampler.sampler_context_cache._lru._n_access > 0, \ + f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }." + #############