Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Stage three sams #417

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 98 additions & 21 deletions openmmtools/multistate/sams.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,14 @@ def __init__(self,
update_stages='two-stage',
flatness_criteria='logZ-flatness',
flatness_threshold=0.2,
minimum_visits=10,
flatness_criteria2='logZ-flatness',
flatness_threshold2=None,
minimum_visits2=None,
weight_update_method='rao-blackwellized',
adapt_target_probabilities=False,
gamma0=1.0,
beta_factor=0.8,
logZ_guess=None,
**kwargs):
"""Initialize a SAMS sampler.
Expand All @@ -196,24 +201,36 @@ def __init__(self,
locality : int, optional, default=1
Number of neighboring states on either side to consider for local update schemes.
update_stages : str, optional, default='two-stage'
One of ['one-stage', 'two-stage']
One of ['one-stage', 'two-stage','three-stage']
``one-stage`` will use the asymptotically optimal scheme throughout the entire simulation (not recommended due to slow convergence)
``two-stage`` will use a heuristic first stage to achieve flat histograms before switching to the asymptotically optimal scheme
flatness_criteria : string, optiona, default='logZ-flatness'
``three-stage`` is two-stage, but with an additional final stage where weights are not updated
flatness_criteria : string, optional, default='logZ-flatness'
Method of assessing when to switch to asymptotically optimal scheme
One of ['logZ-flatness','minimum-visits','histogram-flatness']
flatness_threshold : float, optional, default=0.2
Histogram relative flatness threshold to use for first stage of two-stage scheme.
minimum_visits : int, optional, default=10
Minimum visits per states threshold to use for second stage of three-stage scheme.
flatness_criteria2 : string, optional, default='logZ-flatness'
Method of assessing when to switch to asymptotically optimal scheme for second stage of three-stage scheme.
One of ['logZ-flatness','minimum-visits','histogram-flatness']
flatness_threshold2 : float, optional, default=None
Histogram relative flatness threshold to use for second stage of three-stage scheme.
minimum_visits2 : int, optional, default=None
Minimum visits per states threshold to use for second stage of three-stage scheme.
weight_update_method : str, optional, default='rao-blackwellized'
Method to use for updating log weights in SAMS. One of ['optimal', 'rao-blackwellized']
``rao-blackwellized`` will update log free energy estimate for all states for which energies were computed
``optimal`` will use integral counts to update log free energy estimate of current state only
adapt_target_probabilities : bool, optional, default=False
If True, target probabilities will be adapted to achieve minimal thermodynamic length between terminal thermodynamic states.
(EXPERIMENTAL)
gamma0 : float, optional, default=0.0
gamma0 : float, optional, default=1.0
Initial weight adaptation rate.
logZ_guess : array-like of shape [n_states] of floats, optiona, default=None
beta_factor : float, optional, default=0.8
The decay factor of the weight adaption rate
logZ_guess : array-like of shape [n_states] of floats, optional, default=None
Initial guess for logZ for all states, if available.
"""
# Initialize multi-state sampler
Expand All @@ -225,9 +242,14 @@ def __init__(self,
self.update_stages = update_stages
self.flatness_criteria = flatness_criteria
self.flatness_threshold = flatness_threshold
self.minimum_visits = minimum_visits
self.flatness_criteria2 = flatness_criteria2
self.flatness_threshold2 = flatness_threshold2
self.minimum_visits2 = minimum_visits2
self.weight_update_method = weight_update_method
self.adapt_target_probabilities = adapt_target_probabilities
self.gamma0 = gamma0
self.beta_factor = beta_factor
self.logZ_guess = logZ_guess
# Private variables
# self._replica_neighbors[replica_index] is a list of states that form the neighborhood of ``replica_index``
Expand All @@ -247,7 +269,7 @@ def _state_update_scheme_validator(instance, scheme):

@staticmethod
def _update_stages_validator(instance, scheme):
supported_schemes = ['one-stage', 'two-stage']
supported_schemes = ['one-stage', 'two-stage','three-stage']
if scheme not in supported_schemes:
raise ValueError("Unknown update scheme '{}'. Supported values "
"are {}.".format(scheme, supported_schemes))
Expand Down Expand Up @@ -277,22 +299,48 @@ def _adapt_target_probabilities_validator(instance, scheme):
"are {}.".format(scheme, supported_schemes))
return scheme


log_target_probabilities = _StoredProperty('log_target_probabilities', validate_function=None)
state_update_scheme = _StoredProperty('state_update_scheme', validate_function=_StoredProperty._state_update_scheme_validator)
locality = _StoredProperty('locality', validate_function=None)
update_stages = _StoredProperty('update_stages', validate_function=_StoredProperty._update_stages_validator)
flatness_criteria = _StoredProperty('flatness_criteria', validate_function=_StoredProperty._flatness_criteria_validator)
flatness_threshold = _StoredProperty('flatness_threshold', validate_function=None)
minimum_visits = _StoredProperty('minimum_visits', validate_function=None)
flatness_criteria2 = _StoredProperty('flatness_criteria2', validate_function=_StoredProperty._flatness_criteria_validator) #same set of flatness criteria available
flatness_threshold2 = _StoredProperty('flatness_threshold2', validate_function=None)
minimum_visits2 = _StoredProperty('minimum_visits2', validate_function=None)
weight_update_method = _StoredProperty('weight_update_method', validate_function=_StoredProperty._weight_update_method_validator)
adapt_target_probabilities = _StoredProperty('adapt_target_probabilities', validate_function=_StoredProperty._adapt_target_probabilities_validator)
gamma0 = _StoredProperty('gamma0', validate_function=None)
logZ_guess = _StoredProperty('logZ_guess', validate_function=None)


def _check_stage_options(self):
# if two-stage check no three-stage parameters have been provided
if self.update_stages == 'two-stage':
assert (self.flatness_threshold2 == None) , "If running two-stage SAMS, flatness_threshold2 should not be defined"
assert (self.minimum_visits2 == None) , "If running two-stage SAMS, minimum_visits2 should not be defined"
# if requirement for two and three stages are the same type, check they are tight
if self.update_stages == 'three-stage':
if self.flatness_criteria == self.flatness_criteria2:
logger.debug(f'Flatness criteria for 2nd and 3rd stages are the same: {self.flatness_criteria}')
if self.flatness_criteria == 'minimum-visits':
assert ( self.minimum_visits2 > self.minimum_visits ) , "minimum_visits2 must be larger than minimum_visits for three-stage SAMS"
elif self.flatness_criteria in ['logZ-flatness','histogram-flatness']:
assert ( self.flatness_threshold2 < self.flatness_threshold ) , "flatness_threshold2 must be smaller than flatness_threshold for three-stage SAMS"
# if they're not tight, then throw a warning and proceed with caution
else: # in the case that differing flatness_criteria are used for 2nd and 3rd stage
logger.info(f'Third stage flatness criteria and threshold ({self.flatness_criteria2},{self.flatness_threshold2}) may not be a tighter'
f'requirement than second stage criteria and threshold ({self.flatness_criteria},{self.flatness_threshold}).'
f'If third stage criteria is met before second stage, then simulation will default to \'two-stage\' update scheme')

def _initialize_stage(self):
self._t0 = 0 # reference iteration to subtract
self._t1 = 0 # iteration for 3rd stage SAMS
if self.update_stages == 'one-stage':
self._stage = 1 # start with asymptotically-optimal stage
elif self.update_stages == 'two-stage':
elif self.update_stages in ['two-stage','three-stage']:
self._stage = 0 # start with rapid heuristic adaptation initial stage

def _pre_write_create(self, thermodynamic_states: list, sampler_states: list, storage,
Expand Down Expand Up @@ -367,10 +415,11 @@ def _restore_sampler_from_reporter(self, reporter):
super()._restore_sampler_from_reporter(reporter)
self._cached_state_histogram = self._compute_state_histogram(reporter=reporter)
logger.debug('Restored state histogram: {}'.format(self._cached_state_histogram))
data = reporter.read_online_analysis_data(self._iteration, 'logZ', 'stage', 't0')
data = reporter.read_online_analysis_data(self._iteration, 'logZ', 'stage', 't0','t1')
self._logZ = data['logZ']
self._stage = int(data['stage'][0])
self._t0 = int(data['t0'][0])
self._t1 = int(data['t1'][0])

# Compute log weights from log target probability and logZ estimate
self._update_log_weights()
Expand All @@ -383,7 +432,7 @@ def _restore_sampler_from_reporter(self, reporter):
def _report_iteration_items(self):
super(SAMSSampler, self)._report_iteration_items()

self._reporter.write_online_data_dynamic_and_static(self._iteration, logZ=self._logZ, stage=self._stage, t0=self._t0)
self._reporter.write_online_data_dynamic_and_static(self._iteration, logZ=self._logZ, stage=self._stage, t0=self._t0,t1=self._t1)
# Split into which states and how many samplers are in each state
# Trying to do histogram[replica_thermo_states] += 1 does not correctly handle multiple
# replicas in the same state.
Expand Down Expand Up @@ -421,16 +470,16 @@ def _mix_replicas(self):
n_swaps_accepted = self._n_accepted_matrix.sum()
swap_fraction_accepted = 0.0
if n_swaps_proposed > 0:
# TODO drop casting to float when dropping Python 2 support.
swap_fraction_accepted = float(n_swaps_accepted) / n_swaps_proposed
swap_fraction_accepted = n_swaps_accepted / n_swaps_proposed
logger.debug("Accepted {}/{} attempted swaps ({:.1f}%)".format(n_swaps_accepted, n_swaps_proposed,
swap_fraction_accepted * 100.0))

# Update logZ estimates
self._update_logZ_estimates(replicas_log_P_k)

# Update log weights based on target probabilities
self._update_log_weights()
if self._stage < 2: # not updating weights in the final third stage
self._update_log_weights()

def _local_jump(self, replicas_log_P_k):
n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality
Expand Down Expand Up @@ -562,19 +611,17 @@ def _update_stage(self):
Determine which adaptation stage we're in by checking histogram flatness.

"""
# TODO: Make minimum_visits a user option
minimum_visits = 1
N_k = self._state_histogram
logger.debug(' state histogram counts ({} total): {}'.format(self._cached_state_histogram.sum(), self._cached_state_histogram))
if (self.update_stages == 'two-stage') and (self._stage == 0):
if (self.update_stages in ['two-stage','three-stage']) and (self._stage == 0):
advance = False
if N_k.sum() == 0:
# No samples yet; don't do anything.
return

if self.flatness_criteria == 'minimum-visits':
# Advance if every state has been visited at least once
if np.all(N_k >= minimum_visits):
if np.all(N_k >= self.minimum_visits):
advance = True
elif self.flatness_criteria == 'histogram-flatness':
# Check histogram flatness
Expand All @@ -591,14 +638,46 @@ def _update_stage(self):
if np.all(criteria):
advance = True
else:
raise ValueError("Unknown flatness_criteria %s" % flatness_criteria)
raise ValueError("Unknown flatness_criteria %s" % self.flatness_criteria)

if advance or ((self._t0 > 0) and (self._iteration > self._t0)):
# Histograms are sufficiently flat; switch to asymptotically optimal scheme
self._stage = 1 # asymptotically optimal
# TODO: On resuming, we need to recompute or restore t0, or use some other way to compute it
self._t0 = self._iteration - 1

if (self.update_stages == 'three-stage') and (self._stage == 1):
advance = False
if self.flatness_criteria2 == 'minimum-visits':
# Advance if every state has been visited at least once
if np.all(N_k >= self.minimum_visits2):
advance = True
elif self.flatness_criteria2 == 'histogram-flatness':
# Check histogram flatness
empirical_pi_k = N_k[:] / N_k.sum()
pi_k = np.exp(self.log_target_probabilities)
relative_error_k = np.abs(pi_k - empirical_pi_k) / pi_k
if np.all(relative_error_k < self.flatness_threshold2):
advance = True
elif self.flatness_criteria2 == 'logZ-flatness':
# TODO: Advance to asymptotically optimal scheme when logZ update fractional counts per state exceed threshold
# for all states.
criteria = abs(self._logZ / self.gamma0) > self.flatness_threshold2
logger.debug('logZ-flatness criteria met (%d total): %s' % (np.sum(criteria), str(np.array(criteria, 'i1'))))
if np.all(criteria):
advance = True
else:
raise ValueError("Unknown flatness_criteria %s" % self.flatness_criteria2)

if advance:
# firstly need to ensure that third-stage criteria is tighter than second-stage
if self._t0 == self._iteration - 1: #check that t0 was not this iteration
logger.info('Third stage criteria is tighter than second-stage criteria. Reverting to a two-stage simulation')
self.update_stages = 'two-stage'
else:
self._stage = 2 # fixed-weight scheme
self._t1 = self._iteration - 1

def _update_logZ_estimates(self, replicas_log_P_k):
"""
Update the logZ estimates according to selected SAMS update method
Expand Down Expand Up @@ -627,14 +706,12 @@ def _update_logZ_estimates(self, replicas_log_P_k):
# Update logZ estimates from all replicas
for (replica_index, state_index) in enumerate(self._replica_thermodynamic_states):
logger.debug(' Replica %d state %d' % (replica_index, state_index))
# Compute attenuation factor gamma
beta_factor = 0.8
pi_star = pi_k.min()
t = float(self._iteration)
if self._stage == 0: # initial stage
gamma = self.gamma0 * min(pi_star, t**(-beta_factor)) # Eq. 15 of [1]
gamma = self.gamma0 * min(pi_star, t**(-self.beta_factor)) # Eq. 15 of [1]
elif self._stage == 1:
gamma = self.gamma0 * min(pi_star, (t - self._t0 + self._t0**beta_factor)**(-1)) # Eq. 15 of [1]
gamma = self.gamma0 * min(pi_star, (t - self._t0 + self._t0**self.beta_factor)**(-1)) # Eq. 15 of [1]
else:
raise Exception('stage {} unknown'.format(self._stage))

Expand Down Expand Up @@ -662,7 +739,7 @@ def _update_logZ_estimates(self, replicas_log_P_k):
raise Exception('Programming error: Unreachable code')

# Subtract off logZ[0] to prevent logZ from growing without bound once we reach the asymptotically optimal stage
if self._stage == 1: # asymptotically optimal or one-stage
if self._stage >= 1: # asymptotically optimal or one-stage
self._logZ[:] -= self._logZ[0]

# Format logZ
Expand Down