Skip to content

Commit

Permalink
Merge branch 'main' into multistage
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm authored Dec 21, 2023
2 parents 6af6ef6 + 46166ef commit 7b9206c
Show file tree
Hide file tree
Showing 6 changed files with 405 additions and 85 deletions.
89 changes: 89 additions & 0 deletions Examples/LJ_langevin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from openmmtools.testsystems import LennardJonesFluid

# Use the LennardJonesFluid example from openmmtools to initialize particle positions and topology
# For this example, the topology provides the masses for the particles
# The default LennardJonesFluid example considers the system to be Argon with 39.9 amu
lj_fluid = LennardJonesFluid(reduced_density=0.1, nparticles=1000)


from chiron.potential import LJPotential
from openmm import unit

# initialize the LennardJones potential in chiron
#
sigma = 0.34 * unit.nanometer
epsilon = 0.238 * unit.kilocalories_per_mole
cutoff = 3.0 * sigma

lj_potential = LJPotential(
lj_fluid.topology, sigma=sigma, epsilon=epsilon, cutoff=cutoff
)

from chiron.states import SamplerState, ThermodynamicState

# define the sampler state
sampler_state = SamplerState(
x0=lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors()
)

# define the thermodynamic state
thermodynamic_state = ThermodynamicState(
potential=lj_potential, temperature=300 * unit.kelvin
)

from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace

# define the neighbor list for an orthogonal periodic space
skin = 0.5 * unit.nanometer

nbr_list = NeighborListNsqrd(
OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180
)
from chiron.neighbors import PairList


# build the neighbor list from the sampler state
nbr_list.build_from_state(sampler_state)

from chiron.reporters import SimulationReporter

# initialize a reporter to save the simulation data
filename = "test_lj.h5"
import os

if os.path.isfile(filename):
os.remove(filename)
reporter = SimulationReporter("test_lj.h5", lj_fluid.topology, 1)

from chiron.integrators import LangevinIntegrator

# initialize the Langevin integrator
integrator = LangevinIntegrator(reporter=reporter, save_frequency=100)
print("init_energy: ", lj_potential.compute_energy(sampler_state.x0, nbr_list))

integrator.run(
sampler_state,
thermodynamic_state,
n_steps=5000,
nbr_list=nbr_list,
progress_bar=True,
)

import h5py

# read the data from the reporter
with h5py.File("test_lj.h5", "r") as f:
energies = f["energy"][:]
steps = f["step"][:]


# plot the energy
import matplotlib.pyplot as plt

plt.plot(steps, energies)
plt.xlabel("Step (fs)")
plt.ylabel("Energy (kj/mol)")
plt.show()

print(energies)
print(steps)
6 changes: 5 additions & 1 deletion chiron/minimze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent
from loguru import logger as log


def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000):
"""
Expand All @@ -25,13 +27,15 @@ def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000):

def objective_fn(x):
if nbr_list is not None:
log.debug("Using neighbor list")
return potential_fn(x, nbr_list)
else:
log.debug("Using NO neighbor list")
return potential_fn(x)

optimizer = GradientDescent(
fun=jax.value_and_grad(objective_fn), value_and_grad=True, maxiter=maxiter
)
result = optimizer.run(coordinates)

return result.params
return result
109 changes: 62 additions & 47 deletions chiron/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def check(self, coordinates: jnp.array) -> bool:
"""
Check if the neighbor list needs to be rebuilt based on displacement of the particles from the reference coordinates.
If a particle moves more than 0.5 skin distance, the neighborlist will be rebuilt.
Will also return True if the size of the coordinates array changes.
Note, this could also accept a user defined criteria for distance, but this is not implemented yet.
Expand All @@ -653,6 +654,10 @@ def check(self, coordinates: jnp.array) -> bool:
bool
True if the neighbor list needs to be rebuilt, False if it does not.
"""

if self.ref_coordinates.shape[0] != coordinates.shape[0]:
return True

status = jax.vmap(
self._calculate_particle_displacement, in_axes=(0, None, None)
)(self.particle_ids, coordinates, self.ref_coordinates)
Expand All @@ -673,38 +678,39 @@ class PairList(PairsBase):
space: Space
Class that defines how to calculate the displacement between two points and apply the boundary conditions
cutoff: float, default = 2.5
Cutoff distance for the neighborlist
n_max_neighbors: int, default=200
Maximum number of neighbors for each particle. Used for padding arrays for efficient jax computations
This will be checked and dynamically updated during the build stage
Cutoff distance for the pair list calculation
Examples
--------
>>> from chiron.neighbors import PairList, OrthogonalPeriodicSpace
>>> from chiron.states import SamplerState
>>> import jax.numpy as jnp
>>>
>>> space = OrthogonalPeriodicSpace()
>>> pair_list = PairList(space, cutoff=2.5)
>>> sampler_state = SamplerState(x0=jnp.array([[0.0, 0.0, 0.0], [2, 0.0, 0.0], [0.0, 2, 0.0]]),
>>> box_vectors=jnp.array([[10, 0.0, 0.0], [0.0, 10, 0.0], [0.0, 0.0, 10]]))
>>> pair_list.build_from_state(sampler_state)
>>>
>>> # mask and distances are of shape (n_particles, n_particles-1),
>>> displacement_vectors of shape (n_particles, n_particles-1, 3)
>>> # mask, is a bool array that is True if the particle is within the cutoff distance, False if it is not
>>> # n_pairs is of shape (n_particles) and is per row sum of the mask. The mask ensure we also do not double count pairs
>>> n_pairs, mask, distances, displacement_vectors = pair_list.calculate(sampler_state.x0)
"""

def __init__(
self,
space: Space,
cutoff: unit.Quantity = unit.Quantity(1.2, unit.nanometer),
skin: unit.Quantity = unit.Quantity(0.4, unit.nanometer),
n_max_neighbors: float = 200,
):
if not isinstance(space, Space):
raise TypeError(f"space must be of type Space, found {type(space)}")
if not cutoff.unit.is_compatible(unit.angstrom):
raise ValueError(
f"cutoff must be a unit.Quantity with units of distance, cutoff.unit = {cutoff.unit}"
)
if not skin.unit.is_compatible(unit.angstrom):
raise ValueError(
f"cutoff must be a unit.Quantity with units of distance, skin.unit = {skin.unit}"
)

self.cutoff = cutoff.value_in_unit_system(unit.md_unit_system)
self.skin = skin.value_in_unit_system(unit.md_unit_system)
self.cutoff_and_skin = self.cutoff + self.skin
self.n_max_neighbors = n_max_neighbors
self.space = space

# set a a simple variable to know if this has at least been built once as opposed to just initialized
Expand All @@ -716,7 +722,8 @@ def __init__(
@partial(jax.jit, static_argnums=(0,))
def _pairs_and_mask(self, particle_ids: jnp.array):
"""
Jitted function to generate mask that allows us to remove self-interactions and double-counting of pairs
Jitted function to generate all pairs (excluding self interactions)
and mask that allows us to remove double-counting of pairs.
Parameters
----------
Expand All @@ -725,8 +732,10 @@ def _pairs_and_mask(self, particle_ids: jnp.array):
Returns
-------
jnp.array
Bool mask to remove self-interactions and double-counting of pairs
all_pairs: jnp.array
Array of all pairs (excluding self interactions), of size (n_particles, n_particles-1)
reduction_mask: jnp.array
Bool mask that identifies which pairs to exclude to remove double counting of pairs
"""
# for the nsq approach, we consider the distance between a particle and all other particles in the system
Expand All @@ -748,9 +757,9 @@ def _pairs_and_mask(self, particle_ids: jnp.array):
particles_j, temp_mask
)

temp_mask = jnp.where(particles_i < all_pairs[0], True, False)
reduction_mask = jnp.where(particles_i < all_pairs[0], True, False)

return all_pairs[0], temp_mask
return all_pairs[0], reduction_mask

@partial(jax.jit, static_argnums=(0,))
def _remove_self_interactions(self, particles, temp_mask):
Expand All @@ -769,7 +778,7 @@ def build(
Parameters
----------
coordinates: jnp.array
Shape[N,3] array of particle coordinates
Shape[n_particles,3] array of particle coordinates
box_vectors: jnp.array
Shape[3,3] array of box vectors
Expand Down Expand Up @@ -800,23 +809,18 @@ def build(
f"box_vectors should be a 3x3 array, shape provided: {box_vectors.shape}"
)

self.ref_coordinates = coordinates
self.n_particles = coordinates.shape[0]
self.box_vectors = box_vectors

# the neighborlist assumes that the box vectors do not change between building and calculating the neighbor list
# changes to the box vectors require rebuilding the neighbor list
self.space.box_vectors = self.box_vectors

# store the ids of all the particles
self.particle_ids = jnp.array(
range(0, self.ref_coordinates.shape[0]), dtype=jnp.uint16
)
self.particle_ids = jnp.array(range(0, coordinates.shape[0]), dtype=jnp.uint16)

# calculate which pairs to exclude
self.all_pairs, self.reduction_mask = self._pairs_and_mask(self.particle_ids)
self.n_max_neighbors = self.particle_ids.shape[0] - 1

# temp_mask = particles_i < new_particles

self.is_built = True

Expand All @@ -825,30 +829,34 @@ def _calc_distance_per_particle(
self, particle1, neighbors, neighbor_mask, coordinates
):
"""
Jitted function to calculate the distance between a particle and its neighbors
Jitted function to calculate the distance between a particle and all possible neighbors
Parameters
----------
particle1: int
Particle id
neighbors: jnp.array
Array of particle ids for the neighbors of particle1
Array of particle ids for the possible particle pairs of particle1
neighbor_mask: jnp.array
Mask to exclude padding from the neighbor list of particle1
Mask to exclude double particles to prevent double counting
coordinates: jnp.array
X,Y,Z coordinates of all particles
X,Y,Z coordinates of all particles, shaped (n_particles, 3)
Returns
-------
n_pairs: int
Number of interacting pairs for the particle
mask: jnp.array
Mask to exclude padding from the neighbor list of particle1.
Mask to exclude padding particles not within the cutoff particle1.
If a particle is within the interaction cutoff, the mask is 1, otherwise it is 0
Array has shape (n_particles, n_particles-1) as it excludes self interactions
dist: jnp.array
Array of distances between the particle and its neighbors
Array of distances between the particle and all other particles in the system.
Array has shape (n_particles, n_particles-1) as it excludes self interactions
r_ij: jnp.array
Array of displacement vectors between the particle and its neighbors
Array of displacement vectors between the particle and all other particles in the system.
Array has shape (n_particles, n_particles-1, 3) as it excludes self interactions
"""
# repeat the particle id for each neighbor
particles1 = jnp.repeat(particle1, neighbors.shape[0])
Expand All @@ -873,33 +881,37 @@ def calculate(self, coordinates: jnp.array):
Parameters
----------
coordinates: jnp.array
Shape[N,3] array of particle coordinates
Shape[n_particles,3] array of particle coordinates
Returns
-------
n_neighbors: jnp.array
Array of number of neighbors for each particle
Array of number of interacting particles for each particle
padding_mask: jnp.array
Array of masks to exclude padding from the neighbor list of each particle
Array used to masks non interaction particle pairs,
dist: jnp.array
Array of distances between each particle and its neighbors
Array of distances between each particle and all other particles in the system
r_ij: jnp.array
Array of displacement vectors between each particle and its neighbors
Array of displacement vectors between each particle and all other particles in the system.
"""
# coordinates = sampler_state.x0
# note, we assume the box vectors do not change between building and calculating the neighbor list
# changes to the box vectors require rebuilding the neighbor list
if coordinates.shape[0] != self.n_particles:
raise ValueError(
f"Number of particles cannot changes without rebuilding. "
f"Coordinates must have shape ({self.n_particles}, 3), found {coordinates.shape}"
)

# coordinates = self.space.wrap(coordinates)

n_neighbors, padding_mask, dist, r_ij = jax.vmap(
self._calc_distance_per_particle, in_axes=(0, 0, 0, None)
)(self.particle_ids, self.all_pairs, self.reduction_mask, coordinates)
# mask = mask.reshape(-1, self.n_max_neighbors)

return n_neighbors, padding_mask, dist, r_ij

def check(self, coordinates: jnp.array) -> bool:
"""
Always returns false, as there are no internal lists to be rebuilt for a pairlist
Check if we need to reconstruct internal arrays.
For a simple pairlist this will always return False, unless the number of particles change.
Parameters
----------
Expand All @@ -908,6 +920,9 @@ def check(self, coordinates: jnp.array) -> bool:
Returns
-------
bool
True if the neighbor list needs to be rebuilt, False if it does not.
True if we need to rebuild the neighbor list, False if we do not.
"""
return False
if coordinates.shape[0] != self.n_particles:
return True
else:
return False
Loading

0 comments on commit 7b9206c

Please sign in to comment.