diff --git a/openmmtools/multistate/sams.py b/openmmtools/multistate/sams.py index 8884bbcb8..0774d540f 100644 --- a/openmmtools/multistate/sams.py +++ b/openmmtools/multistate/sams.py @@ -174,9 +174,14 @@ def __init__(self, update_stages='two-stage', flatness_criteria='logZ-flatness', flatness_threshold=0.2, + minimum_visits=10, + flatness_criteria2='logZ-flatness', + flatness_threshold2=None, + minimum_visits2=None, weight_update_method='rao-blackwellized', adapt_target_probabilities=False, gamma0=1.0, + beta_factor=0.8, logZ_guess=None, **kwargs): """Initialize a SAMS sampler. @@ -196,14 +201,24 @@ def __init__(self, locality : int, optional, default=1 Number of neighboring states on either side to consider for local update schemes. update_stages : str, optional, default='two-stage' - One of ['one-stage', 'two-stage'] + One of ['one-stage', 'two-stage','three-stage'] ``one-stage`` will use the asymptotically optimal scheme throughout the entire simulation (not recommended due to slow convergence) ``two-stage`` will use a heuristic first stage to achieve flat histograms before switching to the asymptotically optimal scheme - flatness_criteria : string, optiona, default='logZ-flatness' + ``three-stage`` is two-stage, but with an additional final stage where weights are not updated + flatness_criteria : string, optional, default='logZ-flatness' Method of assessing when to switch to asymptotically optimal scheme One of ['logZ-flatness','minimum-visits','histogram-flatness'] flatness_threshold : float, optional, default=0.2 Histogram relative flatness threshold to use for first stage of two-stage scheme. + minimum_visits : int, optional, default=10 + Minimum visits per state threshold to use for the first stage of the two-stage and three-stage schemes. + flatness_criteria2 : string, optional, default='logZ-flatness' + Method of assessing when to switch from the asymptotically optimal scheme to the final fixed-weight stage of the three-stage scheme. 
+ One of ['logZ-flatness','minimum-visits','histogram-flatness'] + flatness_threshold2 : float, optional, default=None + Histogram relative flatness threshold to use for second stage of three-stage scheme. + minimum_visits2 : int, optional, default=None + Minimum visits per state threshold to use for second stage of three-stage scheme. weight_update_method : str, optional, default='rao-blackwellized' Method to use for updating log weights in SAMS. One of ['optimal', 'rao-blackwellized'] ``rao-blackwellized`` will update log free energy estimate for all states for which energies were computed @@ -211,9 +226,11 @@ def __init__(self, adapt_target_probabilities : bool, optional, default=False If True, target probabilities will be adapted to achieve minimal thermodynamic length between terminal thermodynamic states. (EXPERIMENTAL) - gamma0 : float, optional, default=0.0 + gamma0 : float, optional, default=1.0 Initial weight adaptation rate. - logZ_guess : array-like of shape [n_states] of floats, optiona, default=None + beta_factor : float, optional, default=0.8 + Decay exponent of the weight adaptation rate (the rate scales as ``t**(-beta_factor)`` in the initial stage). + logZ_guess : array-like of shape [n_states] of floats, optional, default=None Initial guess for logZ for all states, if available. 
""" # Initialize multi-state sampler @@ -225,9 +242,14 @@ def __init__(self, self.update_stages = update_stages self.flatness_criteria = flatness_criteria self.flatness_threshold = flatness_threshold + self.minimum_visits = minimum_visits + self.flatness_criteria2 = flatness_criteria2 + self.flatness_threshold2 = flatness_threshold2 + self.minimum_visits2 = minimum_visits2 self.weight_update_method = weight_update_method self.adapt_target_probabilities = adapt_target_probabilities self.gamma0 = gamma0 + self.beta_factor = beta_factor self.logZ_guess = logZ_guess # Private variables # self._replica_neighbors[replica_index] is a list of states that form the neighborhood of ``replica_index`` @@ -247,7 +269,7 @@ def _state_update_scheme_validator(instance, scheme): @staticmethod def _update_stages_validator(instance, scheme): - supported_schemes = ['one-stage', 'two-stage'] + supported_schemes = ['one-stage', 'two-stage','three-stage'] if scheme not in supported_schemes: raise ValueError("Unknown update scheme '{}'. 
Supported values " "are {}.".format(scheme, supported_schemes)) @@ -277,22 +299,48 @@ def _adapt_target_probabilities_validator(instance, scheme): "are {}.".format(scheme, supported_schemes)) return scheme + log_target_probabilities = _StoredProperty('log_target_probabilities', validate_function=None) state_update_scheme = _StoredProperty('state_update_scheme', validate_function=_StoredProperty._state_update_scheme_validator) locality = _StoredProperty('locality', validate_function=None) update_stages = _StoredProperty('update_stages', validate_function=_StoredProperty._update_stages_validator) flatness_criteria = _StoredProperty('flatness_criteria', validate_function=_StoredProperty._flatness_criteria_validator) flatness_threshold = _StoredProperty('flatness_threshold', validate_function=None) + minimum_visits = _StoredProperty('minimum_visits', validate_function=None) + flatness_criteria2 = _StoredProperty('flatness_criteria2', validate_function=_StoredProperty._flatness_criteria_validator) #same set of flatness criteria available + flatness_threshold2 = _StoredProperty('flatness_threshold2', validate_function=None) + minimum_visits2 = _StoredProperty('minimum_visits2', validate_function=None) weight_update_method = _StoredProperty('weight_update_method', validate_function=_StoredProperty._weight_update_method_validator) adapt_target_probabilities = _StoredProperty('adapt_target_probabilities', validate_function=_StoredProperty._adapt_target_probabilities_validator) gamma0 = _StoredProperty('gamma0', validate_function=None) logZ_guess = _StoredProperty('logZ_guess', validate_function=None) + + def _check_stage_options(self): + # if two-stage check no three-stage parameters have been provided + if self.update_stages == 'two-stage': + assert (self.flatness_threshold2 == None) , "If running two-stage SAMS, flatness_threshold2 should not be defined" + assert (self.minimum_visits2 == None) , "If running two-stage SAMS, minimum_visits2 should not be defined" + # if 
requirement for two and three stages are the same type, check they are tight + if self.update_stages == 'three-stage': + if self.flatness_criteria == self.flatness_criteria2: + logger.debug(f'Flatness criteria for 2nd and 3rd stages are the same: {self.flatness_criteria}') + if self.flatness_criteria == 'minimum-visits': + assert ( self.minimum_visits2 > self.minimum_visits ) , "minimum_visits2 must be larger than minimum_visits for three-stage SAMS" + elif self.flatness_criteria in ['logZ-flatness','histogram-flatness']: + assert ( self.flatness_threshold2 < self.flatness_threshold ) , "flatness_threshold2 must be smaller than flatness_threshold for three-stage SAMS" + # if they're not tight, then throw a warning and proceed with caution + else: # in the case that differing flatness_criteria are used for 2nd and 3rd stage + logger.info(f'Third stage flatness criteria and threshold ({self.flatness_criteria2},{self.flatness_threshold2}) may not be a tighter' + f'requirement than second stage criteria and threshold ({self.flatness_criteria},{self.flatness_threshold}).' 
+ f'If third stage criteria is met before second stage, then simulation will default to \'two-stage\' update scheme') + def _initialize_stage(self): self._t0 = 0 # reference iteration to subtract + self._t1 = 0 # iteration for 3rd stage SAMS if self.update_stages == 'one-stage': self._stage = 1 # start with asymptotically-optimal stage - elif self.update_stages == 'two-stage': + elif self.update_stages in ['two-stage','three-stage']: self._stage = 0 # start with rapid heuristic adaptation initial stage def _pre_write_create(self, thermodynamic_states: list, sampler_states: list, storage, @@ -367,10 +415,11 @@ def _restore_sampler_from_reporter(self, reporter): super()._restore_sampler_from_reporter(reporter) self._cached_state_histogram = self._compute_state_histogram(reporter=reporter) logger.debug('Restored state histogram: {}'.format(self._cached_state_histogram)) - data = reporter.read_online_analysis_data(self._iteration, 'logZ', 'stage', 't0') + data = reporter.read_online_analysis_data(self._iteration, 'logZ', 'stage', 't0','t1') self._logZ = data['logZ'] self._stage = int(data['stage'][0]) self._t0 = int(data['t0'][0]) + self._t1 = int(data['t1'][0]) # Compute log weights from log target probability and logZ estimate self._update_log_weights() @@ -383,7 +432,7 @@ def _restore_sampler_from_reporter(self, reporter): def _report_iteration_items(self): super(SAMSSampler, self)._report_iteration_items() - self._reporter.write_online_data_dynamic_and_static(self._iteration, logZ=self._logZ, stage=self._stage, t0=self._t0) + self._reporter.write_online_data_dynamic_and_static(self._iteration, logZ=self._logZ, stage=self._stage, t0=self._t0,t1=self._t1) # Split into which states and how many samplers are in each state # Trying to do histogram[replica_thermo_states] += 1 does not correctly handle multiple # replicas in the same state. 
@@ -421,8 +470,7 @@ def _mix_replicas(self): n_swaps_accepted = self._n_accepted_matrix.sum() swap_fraction_accepted = 0.0 if n_swaps_proposed > 0: - # TODO drop casting to float when dropping Python 2 support. - swap_fraction_accepted = float(n_swaps_accepted) / n_swaps_proposed + swap_fraction_accepted = n_swaps_accepted / n_swaps_proposed logger.debug("Accepted {}/{} attempted swaps ({:.1f}%)".format(n_swaps_accepted, n_swaps_proposed, swap_fraction_accepted * 100.0)) @@ -430,7 +478,8 @@ def _mix_replicas(self): self._update_logZ_estimates(replicas_log_P_k) # Update log weights based on target probabilities - self._update_log_weights() + if self._stage < 2: # not updating weights in the final third stage + self._update_log_weights() def _local_jump(self, replicas_log_P_k): n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality @@ -562,11 +611,9 @@ def _update_stage(self): Determine which adaptation stage we're in by checking histogram flatness. """ - # TODO: Make minimum_visits a user option - minimum_visits = 1 N_k = self._state_histogram logger.debug(' state histogram counts ({} total): {}'.format(self._cached_state_histogram.sum(), self._cached_state_histogram)) - if (self.update_stages == 'two-stage') and (self._stage == 0): + if (self.update_stages in ['two-stage','three-stage']) and (self._stage == 0): advance = False if N_k.sum() == 0: # No samples yet; don't do anything. 
@@ -574,7 +621,7 @@ def _update_stage(self): if self.flatness_criteria == 'minimum-visits': # Advance if every state has been visited at least once - if np.all(N_k >= minimum_visits): + if np.all(N_k >= self.minimum_visits): advance = True elif self.flatness_criteria == 'histogram-flatness': # Check histogram flatness @@ -591,7 +638,7 @@ def _update_stage(self): if np.all(criteria): advance = True else: - raise ValueError("Unknown flatness_criteria %s" % flatness_criteria) + raise ValueError("Unknown flatness_criteria %s" % self.flatness_criteria) if advance or ((self._t0 > 0) and (self._iteration > self._t0)): # Histograms are sufficiently flat; switch to asymptotically optimal scheme @@ -599,6 +646,38 @@ def _update_stage(self): # TODO: On resuming, we need to recompute or restore t0, or use some other way to compute it self._t0 = self._iteration - 1 + if (self.update_stages == 'three-stage') and (self._stage == 1): + advance = False + if self.flatness_criteria2 == 'minimum-visits': + # Advance if every state has been visited at least minimum_visits2 times + if np.all(N_k >= self.minimum_visits2): + advance = True + elif self.flatness_criteria2 == 'histogram-flatness': + # Check histogram flatness + empirical_pi_k = N_k[:] / N_k.sum() + pi_k = np.exp(self.log_target_probabilities) + relative_error_k = np.abs(pi_k - empirical_pi_k) / pi_k + if np.all(relative_error_k < self.flatness_threshold2): + advance = True + elif self.flatness_criteria2 == 'logZ-flatness': + # TODO: Advance to fixed-weight scheme when logZ update fractional counts per state exceed threshold + # for all states. 
+ criteria = abs(self._logZ / self.gamma0) > self.flatness_threshold2 + logger.debug('logZ-flatness criteria met (%d total): %s' % (np.sum(criteria), str(np.array(criteria, 'i1')))) + if np.all(criteria): + advance = True + else: + raise ValueError("Unknown flatness_criteria %s" % self.flatness_criteria2) + + if advance: + # firstly need to ensure that third-stage criteria is tighter than second-stage + if self._t0 == self._iteration - 1: #check that t0 was not this iteration + logger.info('Third stage criteria is tighter than second-stage criteria. Reverting to a two-stage simulation') + self.update_stages = 'two-stage' + else: + self._stage = 2 # fixed-weight scheme + self._t1 = self._iteration - 1 + def _update_logZ_estimates(self, replicas_log_P_k): """ Update the logZ estimates according to selected SAMS update method @@ -627,14 +706,12 @@ def _update_logZ_estimates(self, replicas_log_P_k): # Update logZ estimates from all replicas for (replica_index, state_index) in enumerate(self._replica_thermodynamic_states): logger.debug(' Replica %d state %d' % (replica_index, state_index)) - # Compute attenuation factor gamma - beta_factor = 0.8 pi_star = pi_k.min() t = float(self._iteration) if self._stage == 0: # initial stage - gamma = self.gamma0 * min(pi_star, t**(-beta_factor)) # Eq. 15 of [1] + gamma = self.gamma0 * min(pi_star, t**(-self.beta_factor)) # Eq. 15 of [1] elif self._stage == 1: - gamma = self.gamma0 * min(pi_star, (t - self._t0 + self._t0**beta_factor)**(-1)) # Eq. 15 of [1] + gamma = self.gamma0 * min(pi_star, (t - self._t0 + self._t0**self.beta_factor)**(-1)) # Eq. 
15 of [1] else: raise Exception('stage {} unknown'.format(self._stage)) @@ -662,7 +739,7 @@ def _update_logZ_estimates(self, replicas_log_P_k): raise Exception('Programming error: Unreachable code') # Subtract off logZ[0] to prevent logZ from growing without bound once we reach the asymptotically optimal stage - if self._stage == 1: # asymptotically optimal or one-stage + if self._stage >= 1: # asymptotically optimal or one-stage self._logZ[:] -= self._logZ[0] # Format logZ