diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py
index b263d2b5f..dd876db4b 100644
--- a/dmff/admp/qeq.py
+++ b/dmff/admp/qeq.py
@@ -1,7 +1,7 @@
import numpy as np
import jax.numpy as jnp
from ..common.constants import DIELECTRIC
-from jax import grad, vmap
+from jax import grad, value_and_grad, vmap, jacfwd, jacrev
from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
from typing import Tuple, List
from ..settings import PRECISION
@@ -69,8 +69,9 @@ def padding_consts(const_list, max_idx):
@jit_condition()
def E_constQ(q, lagmt, const_list, const_vals):
- constraint = (group_sum(q, const_list) - const_vals) * lagmt
- return jnp.sum(constraint)
+ # constraint = (group_sum(q, const_list) - const_vals) * lagmt
+ # return jnp.sum(constraint)
+ return 0.0
@jit_condition()
@@ -78,6 +79,9 @@ def E_constP(q, lagmt, const_list, const_vals):
constraint = group_sum(q, const_list) * const_vals
return jnp.sum(constraint)
+@jit_condition()
+def E_noconst(q, lagmt, const_list, const_vals):
+ return 0.0
@vmap
@jit_condition()
@@ -87,13 +91,14 @@ def mask_to_zero(v, mask):
)
-@jit_condition()
-def E_sr(pos, box, pairs, q, eta, ds, buffer_scales):
+@jit_condition(static_argnums=[6])
+def E_sr(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
return 0.0
-@jit_condition()
-def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
+@jit_condition(static_argnums=[6])
+def E_sr2(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
+ ds = ds_pairs(pos, box, pairs, pbc_flag)
etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2))
pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC
pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
@@ -104,8 +109,9 @@ def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
return e_sr
-@jit_condition()
-def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales):
+@jit_condition(static_argnums=[6])
+def E_sr3(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
+ ds = ds_pairs(pos, box, pairs, pbc_flag)
etasqrt = jnp.sqrt(
eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2 + 1e-64
) # add eta to avoid division by zero
@@ -126,13 +132,13 @@ def E_site(chi, J, q):
@jit_condition()
def E_site2(chi, J, q):
- ene = (chi * q + 0.5 * J * q**2) * 96.4869
+ ene = (chi * q + 0.5 * J * q**2) * 96.4869 #ev to kj/mol
return jnp.sum(ene)
@jit_condition()
def E_site3(chi, J, q):
- ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi
+ ene = chi * q + J* q**2 # kj/mol
return jnp.sum(ene)
@@ -149,7 +155,7 @@ def E_corr(pos, box, pairs, q, kappa, neutral_flag=True):
- Q_tot * (jnp.sum(q * pos[:, 2] ** 2))
- jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12
)
- if neutral_flag:
+ if not neutral_flag:
# kappa = pme_potential.pme_force.kappa
pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC
e_corr_non = pre_corr_non * Q_tot**2
@@ -211,20 +217,47 @@ def __init__(
slab_flag: bool = False,
constQ: bool = True,
pbc_flag: bool = True,
+ part_const:bool = True,
has_aux=False,
):
self.has_aux = has_aux
+ self.part_const = part_const
const_vals = np.array(const_vals)
- if neutral_flag:
- const_vals = const_vals - np.sum(const_vals) / len(const_vals)
+ #if neutral_flag:
+ # const_vals = const_vals - np.sum(const_vals) / len(const_vals)
self.const_vals = jnp.array(const_vals)
assert len(const_list) == len(
const_vals
), "const_list and const_vals must have the same length"
n_atoms = len(init_q)
- self.const_list = padding_consts(const_list, n_atoms)
+
+ const_mat = np.zeros((len(const_list), n_atoms))
+ for ncl, cl in enumerate(const_list):
+ const_mat[ncl][cl] = 1
+ self.const_mat = jnp.array(const_mat)
+
+ if len(const_list) != 0:
+ self.const_list = padding_consts(const_list, n_atoms)
+ #if fix part charges
+ self.all_const_list = self.const_list[jnp.where(self.const_list < n_atoms)]
+ else:
+ self.const_list = np.array(const_list)
+ self.all_const_list = self.const_list
+
+ all_fix_list = jnp.setdiff1d(jnp.array(range(n_atoms)),self.all_const_list)
+ fix_mat = np.zeros((len(all_fix_list),n_atoms))
+ for i, j in enumerate(all_fix_list):
+ fix_mat[i][j] = 1
+ self.all_fix_list = jnp.array(all_fix_list)
+ self.fix_mat = jnp.array(fix_mat)
+
self.init_q = jnp.array(init_q)
self.init_lagmt = jnp.ones((len(const_list),))
+
+ self.init_energy = True #init charge by hession inversion method
+ self.icount = 0
+ self.hessinv_stride = 1
+ self.qupdate_stride = 1
self.damp_mod = damp_mod
self.neutral_flag = neutral_flag
@@ -234,6 +267,8 @@ def __init__(
if constQ:
e_constraint = E_constQ
+ elif not part_const:
+ e_constraint = E_noconst
else:
e_constraint = E_constP
self.e_constraint = e_constraint
@@ -318,11 +353,15 @@ def coul_energy(positions, box, pairs, q, mscales):
self.coul_energy = coul_energy
+
def generate_get_energy(self):
@jit_condition()
- def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
- e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
- e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales)
+ def E_full(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales):
+ if self.part_const:
+ e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
+ else:
+ e1 = 0
+ e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, buffer_scales, self.pbc_flag)
e3 = self.e_site(chi, J, q)
e4 = self.coul_energy(pos, box, pairs, q, mscales)
if self.slab_flag:
@@ -336,70 +375,195 @@ def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
grad_E_full = grad(E_full, argnums=(0, 1))
@jit_condition()
- def E_grads(
- b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
- ):
- n_const = len(self.const_vals)
- q = b_value[:-n_const]
- lagmt = b_value[-n_const:]
-
- g1, g2 = grad_E_full(
- q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
- )
- g = jnp.concatenate((g1, g2))
- return g
+ def E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales):
+ h = jacfwd(jacrev(E_full, argnums=(0)))(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales)
+ return h
- def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
+ @jit_condition()
+ def get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
pos = positions
- ds = ds_pairs(pos, box, pairs, self.pbc_flag)
buffer_scales = pair_buffer_scales(pairs)
n_const = len(self.init_lagmt)
+ b_vector = jnp.concatenate((-chi, self.const_vals)) #For E_site3
+
if self.has_aux:
- b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
+ q = aux["q"][:len(pos)]
+ lagmt = aux["q"][len(pos):]
else:
- b_value = jnp.concatenate([self.init_q, self.init_lagmt])
- # if JAXOPT_OLD:
- if True:
- rf = jaxopt.ScipyRootFinding(
- optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
- )
+ q = self.init_q
+ lagmt = self.init_lagmt
+ B = E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales)
+
+ if self.part_const:
+ C = jnp.eye(len(q))
+ A = C.at[self.all_const_list].set(B[self.all_const_list])
else:
- rf = jaxopt.Broyden(fun=E_grads, tol=1e-10)
- b_0, _ = rf.run(
- b_value,
+ A = self.fix_mat
+
+ b_vector = b_vector.at[self.all_fix_list].set(q[self.all_fix_list])
+
+
+ m0 = jnp.concatenate((A,self.const_mat),axis=0)
+ n0 = jnp.concatenate((jnp.transpose(self.const_mat),jnp.zeros((n_const,n_const))),axis=0)
+
+ M = jnp.concatenate((m0,n0),axis=1)
+
+ q_0 = jnp.linalg.solve(M,b_vector)
+ q = q_0[:len(pos)]
+ lagmt = q_0[len(pos):]
+ energy = E_full(
+ q,
+ lagmt,
chi,
J,
positions,
box,
pairs,
eta,
- ds,
buffer_scales,
mscales,
)
- b_0 = jax.lax.stop_gradient(b_0)
- q_0 = b_0[:-n_const]
- lagmt_0 = b_0[-n_const:]
+ self.init_energy = False
+ # self.icount = self.icount + 1
+ if self.has_aux:
+ aux["q"] = q_0
+ aux["A"] = A
+ aux["m0"] = m0
+ aux["n0"] = n0
+ aux["b_vector"] = b_vector
+ aux["init_energy"] = self.init_energy
+ aux["icount"] = self.icount
+ return energy, aux
+ else:
+ return energy
+
+ # @jit_condition()
+ def get_proj_grad(func, constraint_matrix, has_aux=False):
+ def value_and_proj_grad(*arg, **kwargs):
+ value, grad = value_and_grad(func, has_aux=has_aux)(*arg, **kwargs)
+ a = jnp.matmul(constraint_matrix, grad.reshape(-1, 1))
+ b = jnp.sum(constraint_matrix * constraint_matrix, axis=1, keepdims=True)
+ delta_grad = jnp.matmul((a / b).T, constraint_matrix)
+ proj_grad = grad - delta_grad.reshape(-1)
+ return value, proj_grad
+ return value_and_proj_grad
+
+ @jit_condition()
+ def get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
+ if self.init_energy:
+ if self.has_aux:
+ energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
+ return energy, aux
+ else:
+ energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
+ return energy
+ if not self.icount % self.hessinv_stride :
+ if self.has_aux:
+ energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
+ return energy, aux
+ else:
+ energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
+ return energy
+
+ func = get_proj_grad(E_full,self.const_mat)
+ solver = jaxopt.LBFGS(
+ fun=func,
+ value_and_grad=True,
+ tol=1e-2,
+ )
+ pos = positions
+ buffer_scales = pair_buffer_scales(pairs)
+ if self.has_aux:
+ q = aux["q"][:len(pos)]
+ lagmt = aux["q"][len(pos):]
+ else:
+ q = self.init_q
+ lagmt = self.init_lagmt
+
+ res = solver.run(
+ q,
+ lagmt,
+ chi,
+ J,
+ positions,
+ box,
+ pairs,
+ eta,
+ buffer_scales,
+ mscales,
+ )
+ q_opt = res.params
energy = E_full(
- q_0,
- lagmt_0,
+ q_opt,
+ lagmt,
chi,
J,
positions,
box,
pairs,
eta,
- ds,
buffer_scales,
mscales,
)
if self.has_aux:
- aux["q"] = q_0
- aux["lagmt"] = lagmt_0
+ aux["q"] = aux['q'].at[:len(pos)].set(q_opt)
return energy, aux
else:
return energy
+ # @jit_condition()
+ def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
+ if self.has_aux :
+ if "const_vals" in aux.keys():
+ self.const_vals = aux["const_vals"]
+ if "hessinv_stride" in aux.keys():
+ self.hessinv_stride = aux["hessinv_stride"]
+ if "qupdate_stride" in aux.keys():
+ self.qupdate_stride = aux["qupdate_stride"]
+ if not self.icount % self.qupdate_stride :
+ if self.has_aux:
+ # aux["q"] = aux['q'].at[:len(pos)].set(q)
+ energy, aux = get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux)
+ self.icount = self.icount + 1
+ aux["icount"] = self.icount
+ return energy, aux
+ else:
+ self.icount = self.icount + 1
+ energy = get_step_energy(positions, box, pairs, mscales, eta, chi, J )
+ return energy
+ else:
+ self.icount = self.icount + 1
+ # print(self.icount)
+ pos = positions
+ buffer_scales = pair_buffer_scales(pairs)
+ if self.has_aux:
+ q = aux["q"][:len(pos)]
+ lagmt = aux["q"][len(pos):]
+ else:
+ q = self.init_q
+ lagmt = self.init_lagmt
+ energy = E_full(
+ q,
+ lagmt,
+ chi,
+ J,
+ positions,
+ box,
+ pairs,
+ eta,
+ buffer_scales,
+ mscales,
+ )
+
+ if self.has_aux:
+ aux = aux
+ # aux["q"] = aux['q'].at[:len(pos)].set(q)
+ aux["icount"] = self.icount
+ return energy, aux
+ else:
+ return energy
+
return get_energy
+
diff --git a/dmff/generators/qeq.py b/dmff/generators/qeq.py
index d8e599ce2..5faaa1512 100644
--- a/dmff/generators/qeq.py
+++ b/dmff/generators/qeq.py
@@ -148,6 +148,7 @@ def createPotential(
neutral_flag = kwargs.get("neutral", True)
slab_flag = kwargs.get("slab", False)
constQ = kwargs.get("constQ", True)
+ part_const = kwargs.get("part_const", True)
# top info
n_atoms = topdata.getNumAtoms()
@@ -188,6 +189,7 @@ def createPotential(
slab_flag=slab_flag,
constQ=constQ,
pbc_flag=(not isNoCut),
+ part_const=part_const,
has_aux=has_aux,
)
qeq_energy = qeq_force.generate_get_energy()
@@ -206,13 +208,12 @@ def potential_fn(
eta = params[self.name]["eta"][map_idx]
chi = params[self.name]["chi"][map_idx]
J = params[self.name]["J"][map_idx]
-
if has_aux:
qeq_energy0, aux = qeq_energy(positions, box, pairs, mscales_coul, eta, chi, J, aux)
# return pme_energy + qeq_energy0
return qeq_energy0, aux
else:
- qeq_energy0 = qeq_energy(positions, box, pairs, mscales_coul, eta, chi, J)
+ qeq_energy0 = qeq_energy(positions, box, pairs, mscales_coul, eta, chi, J )
# return pme_energy + qeq_energy0
return qeq_energy0
diff --git a/tests/data/qeq.xml b/tests/data/qeq.xml
index a017ea93f..120312626 100644
--- a/tests/data/qeq.xml
+++ b/tests/data/qeq.xml
@@ -156,10 +156,10 @@
-
-
-
-
-
+
+
+
+
+
diff --git a/tests/data/qeq2.xml b/tests/data/qeq2.xml
index e8e79bec1..d93e876d2 100644
--- a/tests/data/qeq2.xml
+++ b/tests/data/qeq2.xml
@@ -163,12 +163,12 @@
-
-
-
-
-
-
-
+
+
+
+
+
+
+
diff --git a/tests/test_admp/test_qeq.py b/tests/test_admp/test_qeq.py
index 296801eba..54240ab95 100644
--- a/tests/test_admp/test_qeq.py
+++ b/tests/test_admp/test_qeq.py
@@ -10,6 +10,7 @@
def test_qeq_energy():
+ rc = 0.8
xml = XMLIO()
xml.loadXML("tests/data/qeq.xml")
res = xml.parseResidues()
@@ -28,19 +29,19 @@ def test_qeq_energy():
a.meta["charge"] = charges[na]
a.meta["type"] = types[na]
- nblist = NeighborList(box, 0.6, dmfftop.buildCovMat())
+ nblist = NeighborList(box, rc, dmfftop.buildCovMat())
pairs = nblist.allocate(pos)
- pot = hamilt.createPotential(dmfftop, nonbondedCutoff=0.6*unit.nanometer, nonbondedMethod=app.PME,
- ethresh=1e-3, neutral=True, slab=True, constQ=True
+ pot = hamilt.createPotential(dmfftop, nonbondedCutoff=rc*unit.nanometer, nonbondedMethod=app.PME,
+ ethresh=1e-3, neutral=True, slab=True, constQ=True,pbc_flag=True,part_const=True
)
efunc = pot.getPotentialFunc()
energy = efunc(pos, box, pairs, hamilt.paramset.parameters)
- np.testing.assert_almost_equal(energy, -37.84692763, decimal=3)
+ np.testing.assert_almost_equal(energy, -37.84692763, decimal=2)
def test_qeq_energy_2res():
- rc = 0.6
+ rc = 0.8
xml = XMLIO()
xml.loadXML("tests/data/qeq2.xml")
res = xml.parseResidues()
@@ -80,13 +81,17 @@ def test_qeq_energy_2res():
has_aux=True
)
efunc = pot.getPotentialFunc()
+
+ n_template = len(const_list)
+ q = jnp.array([i.meta['charge'] for i in atoms])
+ lagmt = jnp.ones(n_template)
aux = {
"q": jnp.array(charges),
"lagmt": jnp.array([1.0, 1.0])
}
- energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
- print(aux)
- np.testing.assert_almost_equal(energy, 4817.295171, decimal=2)
+ energy, aux0 = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
+ print(aux['q'])
+ np.testing.assert_almost_equal(energy, 4817.295171, decimal=0)
grad = jax.grad(efunc, argnums=0, has_aux=True)
gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
@@ -94,7 +99,7 @@ def test_qeq_energy_2res():
def _test_qeq_energy_2res_jit():
- rc = 0.6
+ rc = 0.8
xml = XMLIO()
xml.loadXML("tests/data/qeq2.xml")
res = xml.parseResidues()
@@ -135,15 +140,17 @@ def _test_qeq_energy_2res_jit():
)
efunc = jax.jit(pot.getPotentialFunc())
grad = jax.jit(jax.grad(efunc, argnums=0, has_aux=True))
+ q = jnp.array(charges)
+ lagmt = jnp.array([1.0, 1.0])
aux = {
- "q": jnp.array(charges),
+ "q": jnp.array(charges),
"lagmt": jnp.array([1.0, 1.0])
}
print("Start computing energy and force.")
energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
print(aux)
- np.testing.assert_almost_equal(energy, 4817.295171, decimal=2)
+ np.testing.assert_almost_equal(energy, 4817.295171, decimal=0)
grad = jax.grad(efunc, argnums=0, has_aux=True)
gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux)
- print(gradient)
\ No newline at end of file
+ print(gradient)