Adding Simulation and MultistateSimulation reporter #18
A few comments on how we are implementing the RNG in sampler_state. I also think we should only open the HDF5 file when writing data (and close it when done) to avoid corrupt data and to allow the file to be examined during a run.
cls._key = random.PRNGKey(seed)

@classmethod
def get_random_key(cls) -> int:
If we want to use this as a wrapper, this should also probably take in an optional number of keys to split.
But bigger question: do we want to have this separate wrapper or do we want to just keep this functionality within the SamplerState class?
I agree that there might be a better way to do this. I wanted a solution that (a) requires only a single random seed, provided by the user, for any number of SamplerState instances and (b) passes the responsibility for the PRN stream to the SamplerState (which can then live on different machines). This lets us reproduce the random numbers consistently (if so desired), and also ensures that each SamplerState has a unique PRN stream.
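A minimal sketch of the wrapper idea discussed above, assuming a class-level key that is seeded once and then split on every request. The class name `PRNG`, the method `set_seed`, and the `num_keys` parameter are hypothetical names chosen for illustration, not the API in this PR; only `jax.random.PRNGKey` and `jax.random.split` are real JAX calls.

```python
import jax
from jax import random


class PRNG:
    """Hypothetical process-wide key source, seeded once by the user."""

    _key = None

    @classmethod
    def set_seed(cls, seed: int) -> None:
        # One user-provided seed initializes the root key.
        cls._key = random.PRNGKey(seed)

    @classmethod
    def get_random_key(cls, num_keys: int = 1) -> jax.Array:
        if cls._key is None:
            raise RuntimeError("Call PRNG.set_seed(seed) first.")
        # Split into one carry key plus num_keys fresh keys. The carry key
        # replaces the stored key, so the next call yields different keys.
        keys = random.split(cls._key, num_keys + 1)
        cls._key = keys[0]
        return keys[1:]


# Same seed, same sequence of keys: the streams are reproducible.
PRNG.set_seed(1234)
keys_a = PRNG.get_random_key(2)
PRNG.set_seed(1234)
keys_b = PRNG.get_random_key(2)
```

Each returned key could be handed to one SamplerState, which would then own its stream independently of the others.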
chiron/states.py
Outdated
@@ -22,6 +23,7 @@ class SamplerState:
def __init__(
self,
x0: unit.Quantity,
random_seed: random.PRNGKey,
random.PRNGKey is the current state of the sampler, so I don't think random_seed is the correct name. This should probably be called current_key or current_PRNG_key.
I think the sampler state init function should accept two things:
- random_seed, which would be used to generate the first key.
- current_key, which would just set self._current_key.
If current_key is defined, we would either ignore random_seed (if set) or throw an error.
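The two-argument constructor proposed above could look roughly like the following sketch. It is an illustration under assumptions, not the code in chiron/states.py: the class name `SamplerStateSketch` and the method `new_key` are hypothetical, and of the two options mentioned (ignore the seed or throw an error) it picks the error.

```python
from typing import Optional

from jax import random


class SamplerStateSketch:
    """Hypothetical stand-in for SamplerState, reduced to key handling."""

    def __init__(self, random_seed: Optional[int] = None, current_key=None):
        if current_key is not None and random_seed is not None:
            # Assumption: raise instead of silently ignoring random_seed.
            raise ValueError("Pass either random_seed or current_key, not both.")
        if current_key is None:
            if random_seed is None:
                raise ValueError("Either random_seed or current_key is required.")
            # The seed is only used once, to generate the first key.
            current_key = random.PRNGKey(random_seed)
        self._current_key = current_key

    @property
    def current_key(self):
        return self._current_key

    def new_key(self):
        # Advance the stored key and hand out a fresh subkey.
        self._current_key, subkey = random.split(self._current_key)
        return subkey


state = SamplerStateSketch(random_seed=42)
first = state.new_key()
second = state.new_key()

# Restart-style construction: resume from a previously saved key.
restarted = SamplerStateSketch(current_key=state.current_key)
```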
According to https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html, seed and key are synonyms :-)
But I agree that it is clearer if we call it current_PRNG_key.
I would prefer to only pass current_PRNG_key and use it to generate the first key of the new stream. Is there a use case in which we would like to manually set the current_key?
Yes, I agree they are interchangeable in jax, but I think confusion could arise from how people normally use a seed in a stateful RNG (i.e., the seed sets up your initial state and is not something passed around all the time). I think calling it current_PRNG_key is fine.
The only use case I can think of where we'd want to reset this manually during a run would be restarting a workflow, but even then it would depend on the order in which things are read in and initialized (e.g., the code could execute the same script as before but automatically look in the cached directory for a restart file, in which case the internal classes would already be initialized and would just need their current states and stats updated). I'm not sure it is necessary right now, but it is easy to add that code later depending on how we implement other things.
chiron/mcmc.py
Outdated
@@ -420,7 +411,7 @@ def apply(
log.debug(
f"Move accepted. Energy change: {delta_energy:.3f} kT. Number of accepted moves: {self.n_accepted}."
)
reporter.report(
self.reporter.report(
Add some output frequency logic here.
Done! That said, I think we want to log every successful MCMove, so the report_frequency should be 1.
I'm sure there are cases where we would not want to log every move (certainly not in a lot of the simple test cases), but setting the default to 1 is probably reasonable.
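For illustration, the output frequency logic could be sketched as below. `report_frequency` is the name used in this thread; `ListReporter`, `MoveSketch`, and `on_accept` are hypothetical stand-ins, and the real logic in chiron/mcmc.py may differ.

```python
class ListReporter:
    """Hypothetical reporter that just collects what it is given."""

    def __init__(self):
        self.records = []

    def report(self, data: dict) -> None:
        self.records.append(data)


class MoveSketch:
    """Hypothetical MC move that reports every report_frequency-th acceptance."""

    def __init__(self, reporter: ListReporter, report_frequency: int = 1):
        if report_frequency < 1:
            raise ValueError("report_frequency must be >= 1")
        self.reporter = reporter
        self.report_frequency = report_frequency
        self.n_accepted = 0

    def on_accept(self, delta_energy: float) -> None:
        self.n_accepted += 1
        # With the default of 1, every accepted move is reported.
        if self.n_accepted % self.report_frequency == 0:
            self.reporter.report(
                {"n_accepted": self.n_accepted, "delta_energy": delta_energy}
            )


reporter = ListReporter()
move = MoveSketch(reporter, report_frequency=3)
for _ in range(7):
    move.on_accept(delta_energy=-0.5)
reported_counts = [record["n_accepted"] for record in reporter.records]
```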
self.apply(
thermodynamic_state, sampler_state, self.simulation_reporter, nbr_list
)
self.apply(thermodynamic_state, sampler_state, nbr_list)
if trials % 100 == 0:
I think this is fine for now, because the MC moves are being refactored a bit in PR #14, but this will cause a bit of "conflict": the "apply" function also writes to the reporter, but at a different frequency, so this writes out duplicate and potentially less useful information than apply does.
chiron/reporters.py
Outdated
self.buffer = {}
self.h5file = h5py.File(filename, "a")
log.info(f"Writing simulation data to {filename}")
self.h5file = h5py.File(self.file_path, "a")
We might want to change this so that the file is only opened while writing to it. HDF5 files can be rather grumpy. If we kill the simulation while the file is still open, it may not be readable afterwards (it gives an error when trying to open it). Also, unlike a text file, you can't even make a temporary copy of an open HDF5 file and examine the copy (that also throws an error because the file is open).
That is a good point! I have switched to a context manager for the read and write operations.
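A minimal sketch of the open-only-while-writing pattern, assuming a reporter that buffers values in memory and appends them to resizable datasets on flush. `ReporterSketch`, `buffer_size`, and `flush` are hypothetical names, and the dataset layout is an assumption rather than the layout used in chiron/reporters.py.

```python
import os
import tempfile

import h5py
import numpy as np


class ReporterSketch:
    """Hypothetical reporter that keeps the HDF5 file closed between writes."""

    def __init__(self, file_path: str, buffer_size: int = 10):
        self.file_path = file_path
        self.buffer_size = buffer_size
        self.buffer = {}

    def report(self, data: dict) -> None:
        for key, value in data.items():
            self.buffer.setdefault(key, []).append(value)
        if any(len(values) >= self.buffer_size for values in self.buffer.values()):
            self.flush()

    def flush(self) -> None:
        # The file is open only inside this block and is closed on exit,
        # even if an exception is raised while writing.
        with h5py.File(self.file_path, "a") as h5file:
            for key, values in self.buffer.items():
                data = np.asarray(values)
                if key in h5file:
                    dset = h5file[key]
                    old = dset.shape[0]
                    dset.resize(old + data.shape[0], axis=0)
                    dset[old:] = data
                else:
                    # maxshape=None along axis 0 makes the dataset appendable.
                    h5file.create_dataset(
                        key, data=data, maxshape=(None,) + data.shape[1:], chunks=True
                    )
        self.buffer = {}


path = os.path.join(tempfile.mkdtemp(), "sketch.h5")
reporter = ReporterSketch(path, buffer_size=2)
for step in range(4):
    reporter.report({"potential_energy": float(step)})

# Reading also uses a context manager, so the file is never left open.
with h5py.File(path, "r") as h5file:
    energies = h5file["potential_energy"][:]
```

The trade-off is that data still in the buffer is lost if the process is killed before the next flush, so the buffer size bounds how much can go missing.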
I'm not seeing the change to the context manager in the code. Did you forget to commit it?
…tems.py, chiron/tests/test_multistate.py, chiron/tests/test_potential.py, Examples/LJ_langevin.py, chiron/tests/test_integrators.py, chiron/tests/test_convergence_tests.py, chiron/tests/test_minization.py, chiron/multistate.py, chiron/states.py, chiron/tests/test_pairs.py
I think this looks good and can be merged.
Description
For a given MCMC sequence we want to obtain statistics, snapshots (set of specified indices) and log properties (e.g. density) for each MCMC sampler. The multistate sampler needs its reporter to provide the potential energies and state indices to calculate free energies. Additionally, the multistate sampler needs to save regular checkpoint files from which it is possible to restart the MCMC chain.
Note: This PR also contains the improved PRNG implementation. Now, the random number stream is a broadcaster in the sampler state of each MCMC move. This was necessary to ensure that random number streams were separated, which became an issue during this PR since trajectories were synchronizing.
Todos
Status