Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Simulation and MultistateSimulation reporter #18

Merged
merged 41 commits into from
Jan 18, 2024
Merged

Conversation

wiederm
Copy link
Member

@wiederm wiederm commented Jan 10, 2024

Description

For a given MCMC sequence we want to obtain statistics, snapshots (set of specified indices) and log properties (e.g. density) for each MCMC sampler. The multistate sampler needs its reporter to provide the potential energies and state indices to calculate free energies. Additionally, the multistate sampler needs to save regular checkpoint files from which it is possible to restart the MCMC chain.

Note: This PR also contains the improved PRNG implementation. Now, the random number stream is a broadcaster in the sampler state of each MCMC move. This was necessary to ensure that random number streams were separated, which became an issue during this PR since trajectories were synchronizing.

Todos

  • BaseReporter class
  • LangevinReporter
  • MultistateReporter

Status

  • Ready to go

@wiederm wiederm changed the base branch from main to multistage January 10, 2024 14:23
@wiederm wiederm marked this pull request as draft January 10, 2024 14:24
@wiederm wiederm marked this pull request as ready for review January 15, 2024 07:56
@wiederm wiederm self-assigned this Jan 15, 2024
@wiederm wiederm added the enhancement New feature or request label Jan 15, 2024
Copy link
Member

@chrisiacovella chrisiacovella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few things about how we are implementing the RNG in sampler_state and I think we need to only be opening the HDF5 when writing data (and then closing when done) to avoid corrupt data and allow examination during a run

cls._key = random.PRNGKey(seed)

@classmethod
def get_random_key(cls) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to use this as a wrapper, this should also probably take in an optional number of keys to split.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But bigger question: do we want to have this separate wrapper or do we want to just keep this functionality within the SamplerState class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that there might be a better way to do this. I wanted to have a solution that (a) requires only a single random seed that the user provides for any number of SamplerState and (b) passes the responsibility for the PRN stream to the SamplerState (which then can live on different machines). This lets us reproduce (if so desired) the random numbers consistently, but also ensures that each SamplerState has an unique PRN stream.

chiron/states.py Outdated
@@ -22,6 +23,7 @@ class SamplerState:
def __init__(
self,
x0: unit.Quantity,
random_seed: random.PRNGKey,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random.PRNGKey is the current state of the sampler...hence I don't think random_seed is the current name. This should probably be called current_key or current_PRNG_key.

I think the sampler state init function should accept two things:

  • random_seed that would be used to generate the first key.
  • current_key which would just set self._current_key.
    where if current_key is defined, we will ignore the random_seed if set, or throw an error.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html, seed and key are synonyms :-)
But, I agree that it is clearer if we call it current_PRNG_key.

I would prefer to only pass current_PRNG_key and use it to generate the first key of the new stream. Is there a use case in which we would like to manually set the current_key?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I agree they are interchangeable in jax, but I think confusion could come up from how people normally use seed in a stateless RNG (i.e., seed setting up your initial state, not something called all the time). I think calling it current_PRNG_key is fine.

The only use case I can think of where we'd want to reset this manually during time would be for restarting a workflow, but even then it would be dependent on the order in which things are read in and initialize (like, the code could execute the same script as before but automatically look in the cached directory for a restart file, in which case the internal classes would already be initialized, but just need their current states and stats updated). I'm not sure it is necessary right now but it is easy to add that code in depending on how we implement other things.

chiron/mcmc.py Outdated
@@ -420,7 +411,7 @@ def apply(
log.debug(
f"Move accepted. Energy change: {delta_energy:.3f} kT. Number of accepted moves: {self.n_accepted}."
)
reporter.report(
self.reporter.report(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some output frequency logic here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Even though I think we want to log every successful MCMove, so the report_frequency should be 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure there are cases where we would not want to log every move (certainly not in a lot of the simple test cases), but setting default to be 1 is probably reasonable

self.apply(
thermodynamic_state, sampler_state, self.simulation_reporter, nbr_list
)
self.apply(thermodynamic_state, sampler_state, nbr_list)
if trials % 100 == 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine for now, because the MC moves are getting a bit refactoring in the #14 PR, but this will cause a bit of "conflict" as the "apply" function is also writing to the reporter, but at a different frequency, and this is writing out duplicate and potentially less useful information than in apply.

self.buffer = {}
self.h5file = h5py.File(filename, "a")
log.info(f"Writing simulation data to {filename}")
self.h5file = h5py.File(self.file_path, "a")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to change this to be where the file is only opened when writing to it. HDF5 files can be rather grumpy. If we kill the simulation and the file is still open, it may not be readable (give an error trying to open it). Also, unlike a text file you can't even make a temporary copy of an open HDF5 file to then examine the temporary one (also throws an error being open)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good point! I have switched to context manager for the read-and-write operations.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not seeing that change to the context manager in the code, did you forget to commit that change?

…tems.py, chiron/tests/test_multistate.py, chiron/tests/test_potential.py, Examples/LJ_langevin.py, chiron/tests/test_integrators.py, chiron/tests/test_convergence_tests.py, chiron/tests/test_minization.py, chiron/multistate.py, chiron/states.py, chiron/tests/test_pairs.py
Copy link
Member

@chrisiacovella chrisiacovella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good and can be merged.

@wiederm wiederm merged commit ab0d114 into multistage Jan 18, 2024
@wiederm wiederm mentioned this pull request Jan 18, 2024
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants