diff --git a/openmmtools/multistate/multistatereporter.py b/openmmtools/multistate/multistatereporter.py index 4ac1f5cb..93339f67 100644 --- a/openmmtools/multistate/multistatereporter.py +++ b/openmmtools/multistate/multistatereporter.py @@ -100,6 +100,12 @@ class MultiStateReporter(object): analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple) If specified, it will serialize positions and velocities for the specified particles, at every iteration, in the reporter storage (.nc) file. If empty, no positions or velocities will be stored in this file for any atoms. + position_interval : int, default 1 + the frequency at which to write positions relative to analysis + information, 0 would prevent information being written + velocity_interval : int, default 1 + the frequency at which to write positions relative to analysis + information, 0 would prevent information being written Attributes ---------- @@ -113,7 +119,10 @@ class MultiStateReporter(object): """ def __init__(self, storage, open_mode=None, checkpoint_interval=50, checkpoint_storage=None, - analysis_particle_indices=()): + analysis_particle_indices=(), + position_interval=1, + velocity_interval=1, + ): # Warn that API is experimental logger.warn('Warning: The openmmtools.multistate API is experimental and may change in future releases') @@ -136,6 +145,9 @@ def __init__(self, storage, open_mode=None, self._checkpoint_interval = checkpoint_interval # Cast to tuple no mater what 1-D-like input was given self._analysis_particle_indices = tuple(analysis_particle_indices) + self._position_interval = position_interval + self._velocity_interval = velocity_interval + if open_mode is not None: self.open(open_mode) # TODO: Maybe we want to expose this flag to control ovrwriting/appending @@ -202,6 +214,16 @@ def checkpoint_interval(self): """Returns the checkpoint interval""" return self._checkpoint_interval + @property + def position_interval(self): + """Interval relative to energies that positions are written at""" + return self._position_interval + + @property + def velocity_interval(self): + """Interval relative to energies that velocities are written at""" + return self._velocity_interval + def storage_exists(self, skip_size=False): """ Check if the storage files exist on disk. @@ -415,6 +437,8 @@ def _initialize_storage_file(self, ncfile, nc_name, convention): ncfile.ConventionVersion = '0.2' ncfile.DataUsedFor = nc_name ncfile.CheckpointInterval = self._checkpoint_interval + ncfile.PositionInterval = self._position_interval + ncfile.VelocityInterval = self._velocity_interval # Create and initialize the global variables nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar') @@ -1647,35 +1671,47 @@ def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: i write_iteration = self._calculate_checkpoint_iteration(iteration) else: write_iteration = iteration + + # write out pos/vel - if checkpointing, + # or if interval matches desired frequency + write_pos = (storage_file == 'checkpoint' or + (self._position_interval != 0 + and not (write_iteration % self._position_interval))) + write_vel = (storage_file == 'checkpoint' or + (self._velocity_interval != 0 + and not (write_iteration % self._velocity_interval))) + # Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval if write_iteration is not None: # Store sampler states. - # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments - positions = np.zeros([n_replicas, n_particles, 3]) - for replica_index, sampler_state in enumerate(sampler_states): - # Store positions in memory first - x = sampler_state.positions / unit.nanometers - positions[replica_index, :, :] = x[:, :] - # Store positions - storage.variables['positions'][write_iteration, :, :, :] = positions - - # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments - velocities = np.zeros([n_replicas, n_particles, 3]) - for replica_index, sampler_state in enumerate(sampler_states): - if sampler_state._unitless_velocities is not None: - # Store velocities in memory first - x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities - velocities[replica_index, :, :] = x[:, :] - # Store velocites - # TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored - # sampler_state different from origin. - if 'velocities' not in storage.variables: - # create variable with expected dimensions and shape - storage.createVariable('velocities', storage.variables['positions'].dtype, - dimensions=storage.variables['positions'].dimensions) - storage.variables['velocities'][write_iteration, :, :, :] = velocities - - if is_periodic: + if write_pos: + # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments + positions = np.zeros([n_replicas, n_particles, 3]) + for replica_index, sampler_state in enumerate(sampler_states): + # Store positions in memory first + x = sampler_state.positions / unit.nanometers + positions[replica_index, :, :] = x[:, :] + # Store positions + storage.variables['positions'][write_iteration, :, :, :] = positions + + if write_vel: + # Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments + velocities = np.zeros([n_replicas, n_particles, 3]) + for replica_index, sampler_state in enumerate(sampler_states): + if sampler_state._unitless_velocities is not None: + # Store velocities in memory first + x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities + velocities[replica_index, :, :] = x[:, :] + # Store velocites + # TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored + # sampler_state different from origin. + if 'velocities' not in storage.variables: + # create variable with expected dimensions and shape + storage.createVariable('velocities', storage.variables['positions'].dtype, + dimensions=storage.variables['positions'].dimensions) + storage.variables['velocities'][write_iteration, :, :, :] = velocities + + if is_periodic and write_pos: # Store box vectors and volume. # Allocate whole write to memory first box_vectors = np.zeros([n_replicas, 3, 3]) @@ -1727,21 +1763,26 @@ def _read_sampler_states_from_given_file(self, iteration, storage_file='checkpoi sampler_states = list() for replica_index in range(n_replicas): # Restore positions. - x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64) - positions = unit.Quantity(x, unit.nanometers) + try: + x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64) + positions = unit.Quantity(x, unit.nanometers) + except (IndexError, KeyError): + positions = np.zeros((storage.dimensions['atom'].size, # TODO: analysis_particles or atom here? + storage.dimensions['spatial'].size), dtype=np.float64) # Restore velocities # try-catch exception, enabling reading legacy/older serialized objects from openmmtools<0.21.3 try: x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64) velocities = unit.Quantity(x, unit.nanometer / unit.picoseconds) - except KeyError: # Velocities key/variable not found in serialization (openmmtools<=0.21.2) + except (IndexError, KeyError): # Velocities key/variable not found in serialization (openmmtools<=0.21.2) # pass zeros as velocities when key is not found (<0.21.3 behavior) velocities = np.zeros_like(positions) if 'box_vectors' in storage.variables: # Restore box vectors. x = storage.variables['box_vectors'][read_iteration, replica_index, :, :].astype(np.float64) + # TODO: Are box vectors also variably saved? box_vectors = unit.Quantity(x, unit.nanometers) else: box_vectors = None diff --git a/openmmtools/tests/test_sampling.py b/openmmtools/tests/test_sampling.py index f16188bf..b52796d6 100644 --- a/openmmtools/tests/test_sampling.py +++ b/openmmtools/tests/test_sampling.py @@ -357,20 +357,23 @@ class TestReporter: @staticmethod @contextlib.contextmanager - def temporary_reporter( - checkpoint_interval=1, checkpoint_storage=None, analysis_particle_indices=() - ): + + def temporary_reporter(checkpoint_interval=1, checkpoint_storage=None, + position_interval=1, velocity_interval=1, + analysis_particle_indices=()): """Create and initialize a reporter in a temporary directory.""" with temporary_directory() as tmp_dir_path: storage_file = os.path.join(tmp_dir_path, "temp_dir/test_storage.nc") assert not os.path.isfile(storage_file) - reporter = MultiStateReporter( - storage=storage_file, - open_mode="w", - checkpoint_interval=checkpoint_interval, - checkpoint_storage=checkpoint_storage, - analysis_particle_indices=analysis_particle_indices, - ) + + reporter = MultiStateReporter(storage=storage_file, open_mode='w', + checkpoint_interval=checkpoint_interval, + checkpoint_storage=checkpoint_storage, + analysis_particle_indices=analysis_particle_indices, + position_interval=position_interval, + velocity_interval=velocity_interval, + ) + assert reporter.storage_exists(skip_size=True) yield reporter @@ -561,6 +564,124 @@ def test_write_sampler_states(self): checkpoint_state.box_vectors / unit.nanometer, ) + def test_writer_sampler_states_pos_interval(self): + """ write positions and velocities every other frame""" + analysis_particles = (1, 2) + with self.temporary_reporter(analysis_particle_indices=analysis_particles, + position_interval=2, velocity_interval=2, + checkpoint_interval=2) as reporter: + # Create sampler states. + alanine_test = testsystems.AlanineDipeptideVacuum() + positions = alanine_test.positions + sampler_states = [mmtools.states.SamplerState(positions=positions) + for _ in range(2)] + + # Check that after writing and reading, states are identical. + for iteration in range(3): + reporter.write_sampler_states(sampler_states, iteration=iteration) + reporter.write_last_iteration(iteration) + + # Check first frame + restored_sampler_states = reporter.read_sampler_states(iteration=0) + for state, restored_state in zip(sampler_states, restored_sampler_states): + assert np.allclose(state.positions, restored_state.positions) + # By default stored velocities are zeros if not present in origin sampler_state + assert np.allclose(np.zeros(state.positions.shape), restored_state.velocities) + assert np.allclose(state.box_vectors / unit.nanometer, restored_state.box_vectors / unit.nanometer) + # Second frame should not have positions or velocities + restored_sampler_states = reporter.read_sampler_states(iteration=1, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + # missing values are returned as numpy masked array + # so we check that these arrays are all masked + assert restored_state.positions._value.mask.all() + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + + restored_sampler_states = reporter.read_sampler_states(iteration=2, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + assert np.allclose(state.positions[analysis_particles, :], restored_state.positions) + # By default stored velocities are zeros if not present in origin sampler_state + assert np.allclose(np.zeros((2, 3)), restored_state.velocities) + assert np.allclose(state.box_vectors / unit.nanometer, restored_state.box_vectors / unit.nanometer) + + def test_write_sampler_states_no_vel(self): + """do not write velocities to trajectory file""" + analysis_particles = (1, 2) + with self.temporary_reporter(analysis_particle_indices=analysis_particles, + position_interval=1, velocity_interval=0, + checkpoint_interval=2) as reporter: + # Create sampler states. + alanine_test = testsystems.AlanineDipeptideVacuum() + positions = alanine_test.positions + sampler_states = [mmtools.states.SamplerState(positions=positions) + for _ in range(2)] + + # Check that after writing and reading, states are identical. + for iteration in range(3): + reporter.write_sampler_states(sampler_states, iteration=iteration) + reporter.write_last_iteration(iteration) + + # Check first frame + restored_sampler_states = reporter.read_sampler_states(iteration=0, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + # missing values are returned as numpy masked array + # so we check that these arrays are all masked + assert np.allclose(state.positions[analysis_particles, :], restored_state.positions) + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + + # Second frame should not have positions or velocities + restored_sampler_states = reporter.read_sampler_states(iteration=1, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + assert np.allclose(state.positions[analysis_particles, :], restored_state.positions) + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + + restored_sampler_states = reporter.read_sampler_states(iteration=2, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + assert np.allclose(state.positions[analysis_particles, :], restored_state.positions) + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + + def test_write_sampler_states_no_pos(self): + """do not write positions or velocities to trajectory file""" + analysis_particles = (1, 2) + with self.temporary_reporter(analysis_particle_indices=analysis_particles, + position_interval=0, velocity_interval=0, + checkpoint_interval=2) as reporter: + # Create sampler states. + alanine_test = testsystems.AlanineDipeptideVacuum() + positions = alanine_test.positions + sampler_states = [mmtools.states.SamplerState(positions=positions) + for _ in range(2)] + + # Check that after writing and reading, states are identical. + for iteration in range(3): + reporter.write_sampler_states(sampler_states, iteration=iteration) + reporter.write_last_iteration(iteration) + + # Check first frame + restored_sampler_states = reporter.read_sampler_states(iteration=0, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + # missing values are returned as numpy masked array + # so we check that these arrays are all masked + assert restored_state.positions._value.mask.all() + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + + # Second frame should not have positions or velocities + restored_sampler_states = reporter.read_sampler_states(iteration=1, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + assert restored_state.positions._value.mask.all() + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + + restored_sampler_states = reporter.read_sampler_states(iteration=2, analysis_particles_only=True) + for state, restored_state in zip(sampler_states, restored_sampler_states): + assert restored_state.positions._value.mask.all() + assert restored_state.velocities._value.mask.all() + assert restored_state.box_vectors is None # not periodic + def test_analysis_particle_mismatch(self): """Test that previously stored analysis particles is higher priority.""" blank_analysis_particles = ()