From 4808fc4ac92f11fddb528ea9cd80f035f0e19381 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 19 Dec 2023 18:34:00 +0100 Subject: [PATCH 1/6] update minimization tests --- chiron/tests/test_minization.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index 157df1e..7b90f39 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -14,7 +14,7 @@ def test_minimization(): # initialize potential from chiron.potential import LJPotential - cutoff = unit.Quantity(1., unit.nanometer) + cutoff = unit.Quantity(1.0, unit.nanometer) lj_potential = LJPotential(lj_fluid.topology, cutoff=cutoff) sampler_state = SamplerState( @@ -30,6 +30,8 @@ def test_minimization(): print(lj_potential.compute_energy(sampler_state.x0, nbr_list)) print(lj_potential.compute_energy(sampler_state.x0)) - min_x = minimize_energy(sampler_state.x0, lj_potential.compute_energy, nbr_list) + min_x = minimize_energy( + sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=10_000 + ) e = lj_potential.compute_energy(min_x, nbr_list) - assert jnp.isclose(e, -12506.332) + assert jnp.isclose(e, -13332.688, atol=1) From ba4699544cc09291a397489b84a4d2bebf766d42 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Tue, 19 Dec 2023 14:51:30 -0800 Subject: [PATCH 2/6] Added in tests for pair list. --- chiron/neighbors.py | 107 +++++++++------- .../{test_neighborlist.py => test_pairs.py} | 116 +++++++++++++++++- 2 files changed, 174 insertions(+), 49 deletions(-) rename chiron/tests/{test_neighborlist.py => test_pairs.py} (68%) diff --git a/chiron/neighbors.py b/chiron/neighbors.py index 375f493..995b795 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -641,6 +641,7 @@ def check(self, coordinates: jnp.array) -> bool: """ Check if the neighbor list needs to be rebuilt based on displacement of the particles from the reference coordinates. If a particle moves more than 0.5 skin distance, the neighborlist will be rebuilt. + Will also return True if the size of the coordinates array changes. Note, this could also accept a user defined criteria for distance, but this is not implemented yet. @@ -653,6 +654,10 @@ def check(self, coordinates: jnp.array) -> bool: bool True if the neighbor list needs to be rebuilt, False if it does not. """ + + if self.ref_coordinates.shape[0] != coordinates.shape[0]: + return True + status = jax.vmap( self._calculate_particle_displacement, in_axes=(0, None, None) )(self.particle_ids, coordinates, self.ref_coordinates) @@ -673,22 +678,30 @@ class PairList(PairsBase): space: Space Class that defines how to calculate the displacement between two points and apply the boundary conditions cutoff: float, default = 2.5 - Cutoff distance for the neighborlist - n_max_neighbors: int, default=200 - Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations - This will be checked and dynamically updated during the build stage + Cutoff distance for the pair list calculation Examples -------- - - + >>> from chiron.neighbors import PairList, OrthogonalPeriodicSpace + >>> from chiron.states import SamplerState + >>> import jax.numpy as jnp + >>> + >>> space = OrthogonalPeriodicSpace() + >>> pair_list = PairList(space, cutoff=2.5) + >>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]), + >>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]])) + >>> pair_list.build_from_state(sampler_state) + >>> + >>> # mask and distances are of shape (n_particles, n_particles-1), + >>> displacement_vectors of shape (n_particles, n_particles-1, 3) + >>> # mask, is a bool array that is True if the particle is within the cutoff distance, False if it is not + >>> # n_pairs is of shape (n_particles) and is per row sum of the mask. The mask ensure we also do not double count pairs + >>> n_pairs, mask, distances, displacement_vectors = pair_list.calculate(sampler_state.x0) """ def __init__( self, space: Space, cutoff: unit.Quantity = unit.Quantity(1.2, unit.nanometer), - skin: unit.Quantity = unit.Quantity(0.4, unit.nanometer), - n_max_neighbors: float = 200, ): if not isinstance(space, Space): raise TypeError(f"space must be of type Space, found {type(space)}") @@ -696,15 +709,8 @@ def __init__( raise ValueError( f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}" ) - if not skin.unit.is_compatible(unit.angstrom): - raise ValueError( - f"cutoff must be a unit.Quantity with units of distance, skin.unit = {skin.unit}" - ) self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system) - self.skin = skin.value_in_unit_system(unit.md_unit_system) - self.cutoff_and_skin = self.cutoff + self.skin - self.n_max_neighbors = n_max_neighbors self.space = space # set a a simple variable to know if this has at least been built once as opposed to just initialized @@ -716,7 +722,8 @@ def __init__( @partial(jax.jit, static_argnums=(0,)) def _pairs_and_mask(self, particle_ids: jnp.array): """ - Jitted function to generate mask that allows us to remove self-interactions and double-counting of pairs + Jitted function to generate all pairs (excluding self interactions) + and mask that allows us to remove double-counting of pairs. Parameters ---------- @@ -725,8 +732,10 @@ def _pairs_and_mask(self, particle_ids: jnp.array): Returns ------- - jnp.array - Bool mask to remove self-interactions and double-counting of pairs + all_pairs: jnp.array + Array of all pairs (excluding self interactions), of size (n_particles, n_particles-1) + reduction_mask: jnp.array + Bool mask that identifies which pairs to exclude to remove double counting of pairs """ # for the nsq approach, we consider the distance between a particle and all other particles in the system @@ -748,9 +757,9 @@ def _pairs_and_mask(self, particle_ids: jnp.array): particles_j, temp_mask ) - temp_mask = jnp.where(particles_i < all_pairs[0], True, False) + reduction_mask = jnp.where(particles_i < all_pairs[0], True, False) - return all_pairs[0], temp_mask + return all_pairs[0], reduction_mask @partial(jax.jit, static_argnums=(0,)) def _remove_self_interactions(self, particles, temp_mask): @@ -769,7 +778,7 @@ def build( Parameters ---------- coordinates: jnp.array - Shape[N,3] array of particle coordinates + Shape[n_particles,3] array of particle coordinates box_vectors: jnp.array Shape[3,3] array of box vectors @@ -800,7 +809,7 @@ def build( f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}" ) - self.ref_coordinates = coordinates + self.n_particles = coordinates.shape[0] self.box_vectors = box_vectors # the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list @@ -808,15 +817,10 @@ def build( self.space.box_vectors = self.box_vectors # store the ids of all the particles - self.particle_ids = jnp.array( - range(0, self.ref_coordinates.shape[0]), dtype=jnp.uint16 - ) + self.particle_ids = jnp.array(range(0, coordinates.shape[0]), dtype=jnp.uint16) # calculate which pairs to exclude self.all_pairs, self.reduction_mask = self._pairs_and_mask(self.particle_ids) - self.n_max_neighbors = self.particle_ids.shape[0] - 1 - - # temp_mask = particles_i < new_particles self.is_built = True @@ -825,30 +829,34 @@ def _calc_distance_per_particle( self, particle1, neighbors, neighbor_mask, coordinates ): """ - Jitted function to calculate the distance between a particle and its neighbors + Jitted function to calculate the distance between a particle and all possible neighbors Parameters ---------- particle1: int Particle id neighbors: jnp.array - Array of particle ids for the neighbors of particle1 + Array of particle ids for the possible particle pairs of particle1 neighbor_mask: jnp.array - Mask to exclude padding from the neighbor list of particle1 + Mask to exclude double particles to prevent double counting coordinates: jnp.array - X,Y,Z coordinates of all particles + X,Y,Z coordinates of all particles, shaped (n_particles, 3) Returns ------- n_pairs: int Number of interacting pairs for the particle mask: jnp.array - Mask to exclude padding from the neighbor list of particle1. + Mask to exclude padding particles not within the cutoff particle1. If a particle is within the interaction cutoff, the mask is 1, otherwise it is 0 + Array has shape (n_particles, n_particles-1) as it excludes self interactions dist: jnp.array - Array of distances between the particle and its neighbors + Array of distances between the particle and all other particles in the system. + Array has shape (n_particles, n_particles-1) as it excludes self interactions r_ij: jnp.array - Array of displacement vectors between the particle and its neighbors + Array of displacement vectors between the particle and all other particles in the system. + Array has shape (n_particles, n_particles-1, 3) as it excludes self interactions + """ # repeat the particle id for each neighbor particles1 = jnp.repeat(particle1, neighbors.shape[0]) @@ -873,33 +881,35 @@ def calculate(self, coordinates: jnp.array): Parameters ---------- coordinates: jnp.array - Shape[N,3] array of particle coordinates + Shape[n_particles,3] array of particle coordinates Returns ------- n_neighbors: jnp.array - Array of number of neighbors for each particle + Array of number of interacting particles for each particle padding_mask: jnp.array - Array of masks to exclude padding from the neighbor list of each particle + Array used to masks non interaction particle pairs, dist: jnp.array - Array of distances between each particle and its neighbors + Array of distances between each particle and all other particles in the system r_ij: jnp.array - Array of displacement vectors between each particle and its neighbors + Array of displacement vectors between each particle and all other particles in the system. """ - # coordinates = sampler_state.x0 - # note, we assume the box vectors do not change between building and calculating the neighbor list - # changes to the box vectors require rebuilding the neighbor list + if coordinates.shape[0] != self.n_particles: + raise ValueError( + f"Number of particles cannot changes without rebuilding. " + f"Coordinates must have shape ({n_particles}, 3), found {coordinates.shape}" + ) n_neighbors, padding_mask, dist, r_ij = jax.vmap( self._calc_distance_per_particle, in_axes=(0, 0, 0, None) )(self.particle_ids, self.all_pairs, self.reduction_mask, coordinates) - # mask = mask.reshape(-1, self.n_max_neighbors) return n_neighbors, padding_mask, dist, r_ij def check(self, coordinates: jnp.array) -> bool: """ - Always returns false, as there are no internal lists to be rebuilt for a pairlist + Check if we need to reconstruct internal arrays. + For a simple pairlist this will always return False, unless the number of particles change. Parameters ---------- @@ -908,6 +918,9 @@ def check(self, coordinates: jnp.array) -> bool: Returns ------- bool - True if the neighbor list needs to be rebuilt, False if it does not. + True if we need to rebuild the neighbor list, False if we do not. """ - return False + if coordinates.shape[0] != self.n_particles: + return True + else: + return False diff --git a/chiron/tests/test_neighborlist.py b/chiron/tests/test_pairs.py similarity index 68% rename from chiron/tests/test_neighborlist.py rename to chiron/tests/test_pairs.py index 01e66d7..79bea7f 100644 --- a/chiron/tests/test_neighborlist.py +++ b/chiron/tests/test_pairs.py @@ -2,6 +2,7 @@ import pytest from chiron.neighbors import ( NeighborListNsqrd, + PairList, OrthogonalPeriodicSpace, OrthogonalNonperiodicSpace, ) @@ -67,7 +68,7 @@ def test_orthogonal_nonperiodic_displacement(): def test_neighborlist_pair(): """ - This simple test test aspects of the neighborlist for 2 particles + This simple test of the neighborlist for 2 particles """ coordinates = jnp.array([[0, 0, 0], [1, 0, 0]]) @@ -151,6 +152,10 @@ def test_neighborlist_pair(): coordinates = coordinates + 0.1 assert nbr_list.check(coordinates) == True + coordinates = jnp.array([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) + # we changed number of particles, and thus should rebuild + assert nbr_list.check(coordinates) == True + def test_inputs(): space = OrthogonalPeriodicSpace() @@ -223,7 +228,10 @@ def test_inputs(): ) -def test_neighborlist_pair2(): +def test_neighborlist_pair_multiple_particles(): + """ + Test the neighborlist for multiple particles + """ n_xyz = 2 scale_factor = 2.0 @@ -277,3 +285,107 @@ def test_neighborlist_pair2(): n_interacting, mask, dist, rij = nbr_list.calculate(coordinates) assert jnp.all(n_interacting == jnp.array([3, 2, 2, 1, 2, 1, 1, 0])) + + +def test_pairlist_pair(): + """ + This simple test of the neighborlist for 2 particles + """ + + coordinates = jnp.array([[0, 0, 0], [1, 0, 0]]) + box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + state = SamplerState( + x0=unit.Quantity(coordinates, unit.nanometer), + box_vectors=unit.Quantity(box_vectors, unit.nanometer), + ) + + space = OrthogonalPeriodicSpace() + cutoff = 1.1 + skin = 0.1 + pair_list = PairList( + space, + cutoff=unit.Quantity(cutoff, unit.nanometer), + ) + + assert pair_list.cutoff == cutoff + + pair_list.build_from_state(state) + assert jnp.all(pair_list.all_pairs == jnp.array([[1], [0]], dtype=jnp.int32)) + assert jnp.all(pair_list.reduction_mask == jnp.array([[True], [False]])) + assert pair_list.is_built == True + + n_pairs, mask, dist, displacement = pair_list.calculate(coordinates) + + assert jnp.all(n_pairs == jnp.array([1, 0])) + assert jnp.all(mask == jnp.array([[1], [0]])) + assert jnp.all(dist == jnp.array([[1.0], [1.0]])) + assert displacement.shape == (2, 1, 3) + assert jnp.all(displacement == jnp.array([[[-1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0]]])) + + assert pair_list.check(coordinates) == False + + coordinates = coordinates = jnp.array([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) + # we changed number of particles, and thus should rebuild + assert pair_list.check(coordinates) == True + + +def test_pair_list_multiple_particles(): + # test the pair list for multiple particles + # will compare to neighborlist + n_xyz = 2 + scale_factor = 2.0 + + coord_mesh = jnp.mgrid[0:n_xyz, 0:n_xyz, 0:n_xyz] * scale_factor / n_xyz + + # transform the mesh into a list of coordinates shape (n_atoms, 3) + coordinates = jnp.stack(coord_mesh.reshape(3, -1), axis=1, dtype=jnp.float32) + + box_vectors = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + state = SamplerState( + x0=unit.Quantity(coordinates, unit.nanometer), + box_vectors=unit.Quantity(box_vectors, unit.nanometer), + ) + + space = OrthogonalPeriodicSpace() + # every particle should interact with every other particle + cutoff = 2.1 + skin = 0.1 + pair_list = PairList( + space, + cutoff=unit.Quantity(cutoff, unit.nanometer), + ) + pair_list.build_from_state(state) + + n_interacting, mask, dist, rij = pair_list.calculate(coordinates) + assert jnp.all(n_interacting == jnp.array([7, 6, 5, 4, 3, 2, 1, 0])) + + assert jnp.all(mask.shape == (coordinates.shape[0], coordinates.shape[0] - 1)) + + # compare to nbr_list + nbr_list = NeighborListNsqrd( + space, + cutoff=unit.Quantity(cutoff, unit.nanometer), + skin=unit.Quantity(skin, unit.nanometer), + n_max_neighbors=20, + ) + nbr_list.build_from_state(state) + n_interacting1, mask1, dist1, rij1 = nbr_list.calculate(coordinates) + + # sum up all the distances within range, see if they match those in the nlist + assert jnp.where(mask, dist, 0).sum() == jnp.where(mask1, dist1, 0).sum() + + assert jnp.where( + dist + == jnp.array( + [ + [1.0, 1.0, 1.4142135, 1.0, 1.4142135, 1.4142135, 1.7320508], + [1.0, 1.4142135, 1.0, 1.4142135, 1.0, 1.7320508, 1.4142135], + [1.0, 1.4142135, 1.0, 1.4142135, 1.7320508, 1.0, 1.4142135], + [1.4142135, 1.0, 1.0, 1.7320508, 1.4142135, 1.4142135, 1.0], + [1.0, 1.4142135, 1.4142135, 1.7320508, 1.0, 1.0, 1.4142135], + [1.4142135, 1.0, 1.7320508, 1.4142135, 1.0, 1.4142135, 1.0], + [1.4142135, 1.7320508, 1.0, 1.4142135, 1.0, 1.4142135, 1.0], + [1.7320508, 1.4142135, 1.4142135, 1.0, 1.4142135, 1.0, 1.0], + ], + ) + ) From 316b1be86425d9bc5dc02a0906bbfb62cd7dbefa Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Tue, 19 Dec 2023 15:36:13 -0800 Subject: [PATCH 3/6] Added simple LJ langevin example. --- Examples/LJ_langevin.py | 82 +++++++++++++++++++++++++++++++++++++++++ chiron/potential.py | 63 +++++++++++++++++++------------ 2 files changed, 122 insertions(+), 23 deletions(-) create mode 100644 Examples/LJ_langevin.py diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py new file mode 100644 index 0000000..c4ca98a --- /dev/null +++ b/Examples/LJ_langevin.py @@ -0,0 +1,82 @@ +from openmmtools.testsystems import LennardJonesFluid + +# Use the LennardJonesFluid example from openmmtools to initialize particle positions and topology +# For this example, the topology provides the masses for the particles +# The default LennardJonesFluid example considers the system to be Argon with 39.9 amu +lj_fluid = LennardJonesFluid(reduced_density=0.1, nparticles=1000) + + +from chiron.potential import LJPotential +from openmm import unit + +# initialize the LennardJones potential in chiron +# +sigma = 0.34 * unit.nanometer +epsilon = 0.238 * unit.kilocalories_per_mole +cutoff = 3.0 * sigma + +lj_potential = LJPotential( + lj_fluid.topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff +) + +from chiron.states import SamplerState, ThermodynamicState + +# define the sampler state +sampler_state = SamplerState( + x0=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() +) + +# define the thermodynamic state +thermodynamic_state = ThermodynamicState( + potential=lj_potential, temperature=300 * unit.kelvin +) + +from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + +# define the neighbor list for an orthogonal periodic space +skin = 0.5 * unit.nanometer + +nbr_list = NeighborListNsqrd( + OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 +) +# build the neighbor list from the sampler state +nbr_list.build_from_state(sampler_state) + +from chiron.reporters import SimulationReporter + +# initialize a reporter to save the simulation data +filename = "test_lj.h5" +import os + +if os.path.isfile(filename): + os.remove(filename) +reporter = SimulationReporter("test_lj.h5", lj_fluid.topology, 1) + +from chiron.integrators import LangevinIntegrator + +# initialize the Langevin integrator +integrator = LangevinIntegrator(reporter=reporter, save_frequency=100) + +integrator.run( + sampler_state, + thermodynamic_state, + n_steps=2000, + nbr_list=nbr_list, + progress_bar=True, +) + +import h5py + +# read the data from the reporter +with h5py.File("test_lj.h5", "r") as f: + energies = f["energy"][:] + steps = f["step"][:] + + +# plot the energy +import matplotlib.pyplot as plt + +plt.plot(steps, energies) +plt.xlabel("Step (fs)") +plt.ylabel("Energy (kj/mol)") +plt.show() diff --git a/chiron/potential.py b/chiron/potential.py index 5d3daf5..abb3a03 100644 --- a/chiron/potential.py +++ b/chiron/potential.py @@ -25,7 +25,8 @@ def compute_force(self, positions, nbr_list=None) -> jnp.ndarray: def compute_pairlist(self, positions, cutoff) -> jnp.array: """ - Compute the pairlist for a given set of positions and a cutoff distance. + Compute the pairlist for a given set of positions and a cutoff distance + without using periodic boundary conditions or any specific optimizations. Parameters ---------- @@ -89,13 +90,21 @@ def __init__( if not isinstance(topology, Topology): if not isinstance(topology, property): if topology is not None: - raise TypeError(f"Topology must be a Topology object or None, type(topology) = {type(topology)}") + raise TypeError( + f"Topology must be a Topology object or None, type(topology) = {type(topology)}" + ) if not isinstance(sigma, unit.Quantity): - raise TypeError(f"sigma must be a unit.Quantity, type(sigma) = {type(sigma)}") + raise TypeError( + f"sigma must be a unit.Quantity, type(sigma) = {type(sigma)}" + ) if not isinstance(epsilon, unit.Quantity): - raise TypeError(f"epsilon must be a unit.Quantity, type(epsilon) = {type(epsilon)}") + raise TypeError( + f"epsilon must be a unit.Quantity, type(epsilon) = {type(epsilon)}" + ) if not isinstance(cutoff, unit.Quantity): - raise TypeError(f"cutoff must be a unit.Quantity, type(cutoff) = {type(cutoff)}") + raise TypeError( + f"cutoff must be a unit.Quantity, type(cutoff) = {type(cutoff)}" + ) if not sigma.unit.is_compatible(unit.angstrom): raise ValueError(f"sigma must have units of distance, got {sigma.unit}") @@ -104,8 +113,6 @@ def __init__( if not cutoff.unit.is_compatible(unit.nanometer): raise ValueError(f"cutoff must have units of distance, got {cutoff.unit}") - - self.sigma = sigma.value_in_unit_system( unit.md_unit_system ) # The distance at which the potential is zero @@ -141,11 +148,7 @@ def _compute_energy_masked(self, distance, mask): ) return energy.sum() - def compute_energy( - self, - positions: jnp.array, - nbr_list=None, - ): + def compute_energy(self, positions: jnp.array, nbr_list=None, debug_mode=False): """ Compute the LJ energy. @@ -153,10 +156,9 @@ def compute_energy( ---------- positions : jnp.array The positions of the particles in the system - nbr_list : NeighborList, optional - Instance of the neighborlist class to use. By default, set to None, which will use an N^2 pairlist - shift : bool, optional - Whether to shift the potential energy at the cutoff, by default False + nbr_list : NeighborList, default=None + Instance of a neighbor list or pair list class to use. + If None, an unoptimized N^2 pairlist will be used without PBC conditions. Returns ------- potential_energy : float @@ -166,6 +168,9 @@ def compute_energy( # Compute the pair distances and displacement vectors if nbr_list is None: + log.debug( + "nbr_list is None, computing pairlist using N^2 method without PBC." + ) # Compute the pairlist for a given set of positions and a cutoff distance # Note in this case, we do not need the pairs or displacement vectors # Since we already calculate the distance in the pairlist computation @@ -194,8 +199,10 @@ def compute_energy( raise ValueError("Neighborlist must be built before use") # ensure that the cutoff in the neighbor list is the same as the cutoff in the potential - if nbr_list.cutoff != self.cutoff: - raise ValueError(f"Neighborlist cutoff ({nbr_list.cutoff}) must be the same as the potential cutoff ({self.cutoff})") + if nbr_list.cutoff != self.cutoff: + raise ValueError( + f"Neighborlist cutoff ({nbr_list.cutoff}) must be the same as the potential cutoff ({self.cutoff})" + ) n_neighbors, mask, dist, displacement_vectors = nbr_list.calculate( positions @@ -273,9 +280,13 @@ def __init__( U0: unit.Quantity = 0.0 * unit.kilocalories_per_mole, ): if not isinstance(topology, Topology): - if not isinstance(topology, property): #importing from the topology from the model results in it being a property object + if not isinstance( + topology, property + ): # importing from the topology from the model results in it being a property object if topology is not None: - raise TypeError(f"Topology must be a Topology object or None, type(topology) = {type(topology)}") + raise TypeError( + f"Topology must be a Topology object or None, type(topology) = {type(topology)}" + ) if not isinstance(k, unit.Quantity): raise TypeError(f"k must be a unit.Quantity, type(k) = {type(k)}") if not isinstance(x0, unit.Quantity): @@ -284,11 +295,17 @@ def __init__( raise TypeError(f"U0 must be a unit.Quantity, type(U0) = {type(U0)}") if not k.unit.is_compatible(unit.kilocalories_per_mole / unit.angstrom**2): - raise ValueError(f"k must be a unit.Quantity with units of energy per distance squared, k.unit = {k.unit}") + raise ValueError( + f"k must be a unit.Quantity with units of energy per distance squared, k.unit = {k.unit}" + ) if not x0.unit.is_compatible(unit.angstrom): - raise ValueError(f"x0 must be a unit.Quantity with units of distance, x0.unit = {x0.unit}") + raise ValueError( + f"x0 must be a unit.Quantity with units of distance, x0.unit = {x0.unit}" + ) if not U0.unit.is_compatible(unit.kilocalories_per_mole): - raise ValueError(f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}") + raise ValueError( + f"U0 must be a unit.Quantity with units of energy, U0.unit = {U0.unit}" + ) log.info("Initializing HarmonicOscillatorPotential") log.info(f"k = {k}") From 15606c831f344e2fb18a948833f794e9ca73b3df Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 20 Dec 2023 14:39:09 +0100 Subject: [PATCH 4/6] make sure that optimization doesn't result in Nan values --- chiron/minimze.py | 6 +++- chiron/tests/test_minization.py | 59 +++++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/chiron/minimze.py b/chiron/minimze.py index a5083be..0547e81 100644 --- a/chiron/minimze.py +++ b/chiron/minimze.py @@ -1,6 +1,8 @@ import jax import jax.numpy as jnp from jaxopt import GradientDescent +from loguru import logger as log + def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000): """ @@ -25,8 +27,10 @@ def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000): def objective_fn(x): if nbr_list is not None: + log.debug("Using neighbor list") return potential_fn(x, nbr_list) else: + log.debug("Using NO neighbor list") return potential_fn(x) optimizer = GradientDescent( @@ -34,4 +38,4 @@ def objective_fn(x): ) result = optimizer.run(coordinates) - return result.params + return result diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index 7b90f39..e3432ac 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -1,16 +1,15 @@ def test_minimization(): from chiron.minimze import minimize_energy - import jax import jax.numpy as jnp from chiron.states import SamplerState - from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace + from chiron.neighbors import PairList, OrthogonalPeriodicSpace from openmm import unit # initialize testystem from openmmtools.testsystems import LennardJonesFluid - lj_fluid = LennardJonesFluid(reduced_density=0.1, n_particles=100) + lj_fluid = LennardJonesFluid(reduced_density=0.1, n_particles=200) # initialize potential from chiron.potential import LJPotential @@ -20,18 +19,52 @@ def test_minimization(): sampler_state = SamplerState( lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors() ) - skin = unit.Quantity(0.1, unit.nanometer) + # use parilist + nbr_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) + nbr_list.build_from_state(sampler_state) + + # compute intial energy with and without pairlist + initial_e_with_nbr_list = lj_potential.compute_energy(sampler_state.x0, nbr_list) + initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.x0) + print(f"initial_e_with_nbr_list: {initial_e_with_nbr_list}") + print(f"initial_e_without_nbr_list: {initial_e_without_nbr_list}") - nbr_list = NeighborListNsqrd( - OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 + # minimize energy for 0 steps + results = minimize_energy( + sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0 + ) + + # check that the minimization did not change the energy + min_x = results.params + # after 0 steps of minimization + after_0_steps_minimization_e_with_nbr_list = lj_potential.compute_energy( + min_x, nbr_list + ) + after_0_steps_minimization_e_without_nbr_list = lj_potential.compute_energy( + sampler_state.x0 + ) + print( + f"after_0_steps_minimization_e_with_nbr_list: {after_0_steps_minimization_e_with_nbr_list}" + ) + print( + f"after_0_steps_minimization_e_without_nbr_list: {after_0_steps_minimization_e_without_nbr_list}" + ) + assert jnp.isclose( + initial_e_with_nbr_list, after_0_steps_minimization_e_with_nbr_list ) - nbr_list.build_from_state(sampler_state) - print(lj_potential.compute_energy(sampler_state.x0, nbr_list)) - print(lj_potential.compute_energy(sampler_state.x0)) + assert jnp.isclose( + initial_e_without_nbr_list, after_0_steps_minimization_e_without_nbr_list + ) - min_x = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=10_000 + # after 100 steps of minimization + results = minimize_energy( + sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=100 ) - e = lj_potential.compute_energy(min_x, nbr_list) - assert jnp.isclose(e, -13332.688, atol=1) + min_x = results.params + e_min = lj_potential.compute_energy(min_x, nbr_list) + print(f"e_min: {e_min}") + # test that e_min is smaller than initial_e_with_nbr_list + assert e_min < initial_e_with_nbr_list + # test that e is not Nan + assert not jnp.isnan(lj_potential.compute_energy(min_x, nbr_list)) From 254d485fcbeb33f3afe7b05365eaaf20dc162e77 Mon Sep 17 00:00:00 2001 From: wiederm Date: Wed, 20 Dec 2023 18:19:19 +0100 Subject: [PATCH 5/6] small update to min test --- chiron/tests/test_minization.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index e3432ac..4308e6d 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -28,7 +28,9 @@ def test_minimization(): initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.x0) print(f"initial_e_with_nbr_list: {initial_e_with_nbr_list}") print(f"initial_e_without_nbr_list: {initial_e_without_nbr_list}") - + assert not jnp.isclose( + initial_e_with_nbr_list, initial_e_without_nbr_list + ), "initial_e_with_nbr_list and initial_e_without_nbr_list should not be close" # minimize energy for 0 steps results = minimize_energy( sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0 @@ -58,13 +60,14 @@ def test_minimization(): ) # after 100 steps of minimization + steps = 100 results = minimize_energy( - sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=100 + sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=steps ) min_x = results.params e_min = lj_potential.compute_energy(min_x, nbr_list) - print(f"e_min: {e_min}") + print(f"e_min after {steps} of minimization: {e_min}") # test that e_min is smaller than initial_e_with_nbr_list assert e_min < initial_e_with_nbr_list # test that e is not Nan - assert not jnp.isnan(lj_potential.compute_energy(min_x, nbr_list)) + assert not jnp.isnan(lj_potential.compute_energy(min_x, nbr_list)) \ No newline at end of file From 97a629f32950d64ab8a28fc57b497fb46f38dfba Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Wed, 20 Dec 2023 09:30:10 -0800 Subject: [PATCH 6/6] Added simple LJ langevin example. --- Examples/LJ_langevin.py | 9 ++++++- chiron/neighbors.py | 4 ++- chiron/tests/test_minization.py | 44 +++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/Examples/LJ_langevin.py b/Examples/LJ_langevin.py index c4ca98a..93f4415 100644 --- a/Examples/LJ_langevin.py +++ b/Examples/LJ_langevin.py @@ -39,6 +39,9 @@ nbr_list = NeighborListNsqrd( OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) +from chiron.neighbors import PairList + + # build the neighbor list from the sampler state nbr_list.build_from_state(sampler_state) @@ -56,11 +59,12 @@ # initialize the Langevin integrator integrator = LangevinIntegrator(reporter=reporter, save_frequency=100) +print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list)) integrator.run( sampler_state, thermodynamic_state, - n_steps=2000, + n_steps=5000, nbr_list=nbr_list, progress_bar=True, ) @@ -80,3 +84,6 @@ plt.xlabel("Step (fs)") plt.ylabel("Energy (kj/mol)") plt.show() + +print(energies) +print(steps) diff --git a/chiron/neighbors.py b/chiron/neighbors.py index 995b795..a8c7bbe 100644 --- a/chiron/neighbors.py +++ b/chiron/neighbors.py @@ -897,9 +897,11 @@ def calculate(self, coordinates: jnp.array): if coordinates.shape[0] != self.n_particles: raise ValueError( f"Number of particles cannot changes without rebuilding. " - f"Coordinates must have shape ({n_particles}, 3), found {coordinates.shape}" + f"Coordinates must have shape ({self.n_particles}, 3), found {coordinates.shape}" ) + # coordinates = self.space.wrap(coordinates) + n_neighbors, padding_mask, dist, r_ij = jax.vmap( self._calc_distance_per_particle, in_axes=(0, 0, 0, None) )(self.particle_ids, self.all_pairs, self.reduction_mask, coordinates) diff --git a/chiron/tests/test_minization.py b/chiron/tests/test_minization.py index e3432ac..42328e3 100644 --- a/chiron/tests/test_minization.py +++ b/chiron/tests/test_minization.py @@ -68,3 +68,47 @@ def test_minimization(): assert e_min < initial_e_with_nbr_list # test that e is not Nan assert not jnp.isnan(lj_potential.compute_energy(min_x, nbr_list)) + + +def test_minimize_two_particles(): + # this test will check to see if we can minimize the energy of two particles + # to the minimum of the LJ potential + + from chiron.minimze import minimize_energy + import jax.numpy as jnp + + from chiron.states import SamplerState + from chiron.neighbors import PairList, OrthogonalPeriodicSpace + from openmm import unit + from chiron.potential import LJPotential + + sigma = 1.0 * unit.nanometer + epsilon = 1.0 * unit.kilojoules_per_mole + cutoff = 3.0 * sigma + + lj_potential = LJPotential(None, sigma=sigma, epsilon=epsilon, cutoff=cutoff) + + coordinates = jnp.array([[0.0, 0.0, 0.0], [0.9, 0.0, 0.0]]) + + # define the sampler state + sampler_state = SamplerState( + x0=coordinates * unit.nanometer, + box_vectors=jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + * unit.nanometer, + ) + + pair_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff) + pair_list.build_from_state(sampler_state) + + e_start = lj_potential.compute_energy(coordinates, pair_list) + + min_x = minimize_energy( + coordinates, lj_potential.compute_energy, pair_list, maxiter=10_000 + ) + min_x = min_x.params + dist = jnp.linalg.norm(min_x[1] - min_x[0]) + + e_final = lj_potential.compute_energy(min_x, pair_list) + + assert jnp.isclose(e_final, -1.0, atol=1e-3) + assert jnp.isclose(dist, 2 ** (1.0 / 6.0), atol=1e-3)