Skip to content

Commit

Permalink
Refactor MCMC moves and update move set class
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Dec 29, 2023
1 parent 409f5c7 commit 73776f5
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions chiron/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@
"""
from chiron.states import SamplerState, ThermodynamicState
from chiron.potential import NeuralNetworkPotential
from openmm import unit
from loguru import logger as log
from typing import Dict, Union, Tuple, List, Optional
from typing import Tuple, List, Optional
import jax.numpy as jnp
from chiron.reporters import SimulationReporter


class StateUpdateMove:
class MCMCMove:
def __init__(self, nr_of_moves: int, seed: int):
"""
Initialize a move within the molecular system.
Expand All @@ -64,7 +63,7 @@ def __init__(self, nr_of_moves: int, seed: int):
self.key = jrandom.PRNGKey(seed) # 'seed' is an integer seed value


class LangevinDynamicsMove(StateUpdateMove):
class LangevinDynamicsMove(MCMCMove):
def __init__(
self,
stepsize=1.0 * unit.femtoseconds,
Expand Down Expand Up @@ -110,6 +109,13 @@ def run(
state_variables (StateVariablesCollection): State variables of the system.
"""

assert isinstance(
sampler_state, SamplerState
), f"Sampler state must be SamplerState, not {type(sampler_state)}"
assert isinstance(
thermodynamic_state, ThermodynamicState
), f"Thermodynamic state must be ThermodynamicState, not {type(thermodynamic_state)}"

self.integrator.run(
thermodynamic_state=thermodynamic_state,
sampler_state=sampler_state,
Expand All @@ -118,7 +124,7 @@ def run(
)


class MCMove(StateUpdateMove):
class MCMove(MCMCMove):
def __init__(self, nr_of_moves: int, seed: int) -> None:
super().__init__(nr_of_moves, seed)

Expand Down Expand Up @@ -247,7 +253,7 @@ class MoveSet:

def __init__(
self,
move_schedule: List[Tuple[str, StateUpdateMove]],
move_schedule: List[Tuple[str, MCMCMove]],
) -> None:
_AVAILABLE_MOVES = ["LangevinDynamicsMove"]
self.move_schedule = move_schedule
Expand All @@ -264,7 +270,7 @@ def _validate_sequence(self):
If a move in the sequence is not present in available_moves.
"""
for move_name, move_class in self.move_schedule:
if not isinstance(move_class, StateUpdateMove):
if not isinstance(move_class, MCMCMove):
raise ValueError(f"Move {move_name} in the sequence is not available.")


Expand Down

0 comments on commit 73776f5

Please sign in to comment.