From 73776f5b28f8f40927a836278ed9381b48b46aa3 Mon Sep 17 00:00:00 2001 From: wiederm Date: Fri, 29 Dec 2023 16:25:13 +0100 Subject: [PATCH] Refactor MCMC moves and update move set class --- chiron/mcmc.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/chiron/mcmc.py b/chiron/mcmc.py index a767f6c..c67730e 100644 --- a/chiron/mcmc.py +++ b/chiron/mcmc.py @@ -38,15 +38,14 @@ """ from chiron.states import SamplerState, ThermodynamicState -from chiron.potential import NeuralNetworkPotential from openmm import unit from loguru import logger as log -from typing import Dict, Union, Tuple, List, Optional +from typing import Tuple, List, Optional import jax.numpy as jnp from chiron.reporters import SimulationReporter -class StateUpdateMove: +class MCMCMove: def __init__(self, nr_of_moves: int, seed: int): """ Initialize a move within the molecular system. @@ -64,7 +63,7 @@ def __init__(self, nr_of_moves: int, seed: int): self.key = jrandom.PRNGKey(seed) # 'seed' is an integer seed value -class LangevinDynamicsMove(StateUpdateMove): +class LangevinDynamicsMove(MCMCMove): def __init__( self, stepsize=1.0 * unit.femtoseconds, @@ -110,6 +109,13 @@ def run( state_variables (StateVariablesCollection): State variables of the system. """ + assert isinstance( + sampler_state, SamplerState + ), f"Sampler state must be SamplerState, not {type(sampler_state)}" + assert isinstance( + thermodynamic_state, ThermodynamicState + ), f"Thermodynamic state must be ThermodynamicState, not {type(thermodynamic_state)}" + self.integrator.run( thermodynamic_state=thermodynamic_state, sampler_state=sampler_state, @@ -118,7 +124,7 @@ def run( ) -class MCMove(StateUpdateMove): +class MCMove(MCMCMove): def __init__(self, nr_of_moves: int, seed: int) -> None: super().__init__(nr_of_moves, seed) @@ -247,7 +253,7 @@ class MoveSet: def __init__( self, - move_schedule: List[Tuple[str, StateUpdateMove]], + move_schedule: List[Tuple[str, MCMCMove]], ) -> None: _AVAILABLE_MOVES = ["LangevinDynamicsMove"] self.move_schedule = move_schedule @@ -264,7 +270,7 @@ def _validate_sequence(self): If a move in the sequence is not present in available_moves. """ for move_name, move_class in self.move_schedule: - if not isinstance(move_class, StateUpdateMove): + if not isinstance(move_class, MCMCMove): raise ValueError(f"Move {move_name} in the sequence is not available.")