Skip to content

Commit

Permalink
make sure that optimization doesn't result in Nan values
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Dec 20, 2023
1 parent 316b1be commit 15606c8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
6 changes: 5 additions & 1 deletion chiron/minimze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent
from loguru import logger as log


def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000):
"""
Expand All @@ -25,13 +27,15 @@ def minimize_energy(coordinates, potential_fn, nbr_list=None, maxiter=1000):

def objective_fn(x):
if nbr_list is not None:
log.debug("Using neighbor list")
return potential_fn(x, nbr_list)
else:
log.debug("Using NO neighbor list")
return potential_fn(x)

optimizer = GradientDescent(
fun=jax.value_and_grad(objective_fn), value_and_grad=True, maxiter=maxiter
)
result = optimizer.run(coordinates)

return result.params
return result
59 changes: 46 additions & 13 deletions chiron/tests/test_minization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
def test_minimization():
from chiron.minimze import minimize_energy
import jax
import jax.numpy as jnp

from chiron.states import SamplerState
from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace
from chiron.neighbors import PairList, OrthogonalPeriodicSpace
from openmm import unit

# initialize testystem
from openmmtools.testsystems import LennardJonesFluid

lj_fluid = LennardJonesFluid(reduced_density=0.1, n_particles=100)
lj_fluid = LennardJonesFluid(reduced_density=0.1, n_particles=200)
# initialize potential
from chiron.potential import LJPotential

Expand All @@ -20,18 +19,52 @@ def test_minimization():
sampler_state = SamplerState(
lj_fluid.positions, box_vectors=lj_fluid.system.getDefaultPeriodicBoxVectors()
)
skin = unit.Quantity(0.1, unit.nanometer)
# use parilist
nbr_list = PairList(OrthogonalPeriodicSpace(), cutoff=cutoff)
nbr_list.build_from_state(sampler_state)

# compute intial energy with and without pairlist
initial_e_with_nbr_list = lj_potential.compute_energy(sampler_state.x0, nbr_list)
initial_e_without_nbr_list = lj_potential.compute_energy(sampler_state.x0)
print(f"initial_e_with_nbr_list: {initial_e_with_nbr_list}")
print(f"initial_e_without_nbr_list: {initial_e_without_nbr_list}")

nbr_list = NeighborListNsqrd(
OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180
# minimize energy for 0 steps
results = minimize_energy(
sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=0
)

# check that the minimization did not change the energy
min_x = results.params
# after 0 steps of minimization
after_0_steps_minimization_e_with_nbr_list = lj_potential.compute_energy(
min_x, nbr_list
)
after_0_steps_minimization_e_without_nbr_list = lj_potential.compute_energy(
sampler_state.x0
)
print(
f"after_0_steps_minimization_e_with_nbr_list: {after_0_steps_minimization_e_with_nbr_list}"
)
print(
f"after_0_steps_minimization_e_without_nbr_list: {after_0_steps_minimization_e_without_nbr_list}"
)
assert jnp.isclose(
initial_e_with_nbr_list, after_0_steps_minimization_e_with_nbr_list
)
nbr_list.build_from_state(sampler_state)

print(lj_potential.compute_energy(sampler_state.x0, nbr_list))
print(lj_potential.compute_energy(sampler_state.x0))
assert jnp.isclose(
initial_e_without_nbr_list, after_0_steps_minimization_e_without_nbr_list
)

min_x = minimize_energy(
sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=10_000
# after 100 steps of minimization
results = minimize_energy(
sampler_state.x0, lj_potential.compute_energy, nbr_list, maxiter=100
)
e = lj_potential.compute_energy(min_x, nbr_list)
assert jnp.isclose(e, -13332.688, atol=1)
min_x = results.params
e_min = lj_potential.compute_energy(min_x, nbr_list)
print(f"e_min: {e_min}")
# test that e_min is smaller than initial_e_with_nbr_list
assert e_min < initial_e_with_nbr_list
# test that e is not Nan
assert not jnp.isnan(lj_potential.compute_energy(min_x, nbr_list))

0 comments on commit 15606c8

Please sign in to comment.