diff --git a/dmff/admp/qeq.py b/dmff/admp/qeq.py index b263d2b5f..dd876db4b 100644 --- a/dmff/admp/qeq.py +++ b/dmff/admp/qeq.py @@ -1,7 +1,7 @@ import numpy as np import jax.numpy as jnp from ..common.constants import DIELECTRIC -from jax import grad, vmap +from jax import grad, value_and_grad, vmap, jacfwd, jacrev from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce from typing import Tuple, List from ..settings import PRECISION @@ -69,8 +69,9 @@ def padding_consts(const_list, max_idx): @jit_condition() def E_constQ(q, lagmt, const_list, const_vals): - constraint = (group_sum(q, const_list) - const_vals) * lagmt - return jnp.sum(constraint) + # constraint = (group_sum(q, const_list) - const_vals) * lagmt + # return jnp.sum(constraint) + return 0.0 @jit_condition() @@ -78,6 +79,9 @@ def E_constP(q, lagmt, const_list, const_vals): constraint = group_sum(q, const_list) * const_vals return jnp.sum(constraint) +@jit_condition() +def E_noconst(q, lagmt, const_list, const_vals): + return 0.0 @vmap @jit_condition() @@ -87,13 +91,14 @@ def mask_to_zero(v, mask): ) -@jit_condition() -def E_sr(pos, box, pairs, q, eta, ds, buffer_scales): +@jit_condition(static_argnums=[6]) +def E_sr(pos, box, pairs, q, eta, buffer_scales, pbc_flag): return 0.0 -@jit_condition() -def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales): +@jit_condition(static_argnums=[6]) +def E_sr2(pos, box, pairs, q, eta, buffer_scales, pbc_flag): + ds = ds_pairs(pos, box, pairs, pbc_flag) etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2)) pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC @@ -104,8 +109,9 @@ def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales): return e_sr -@jit_condition() -def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales): +@jit_condition(static_argnums=[6]) +def E_sr3(pos, box, pairs, q, eta, buffer_scales, pbc_flag): + ds = ds_pairs(pos, box, pairs, pbc_flag) etasqrt = jnp.sqrt( eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2 + 1e-64 ) # add eta to avoid division by zero @@ -126,13 +132,13 @@ def E_site(chi, J, q): @jit_condition() def E_site2(chi, J, q): - ene = (chi * q + 0.5 * J * q**2) * 96.4869 + ene = (chi * q + 0.5 * J * q**2) * 96.4869 #ev to kj/mol return jnp.sum(ene) @jit_condition() def E_site3(chi, J, q): - ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi + ene = chi * q + J* q**2 # kj/mol return jnp.sum(ene) @@ -149,7 +155,7 @@ def E_corr(pos, box, pairs, q, kappa, neutral_flag=True): - Q_tot * (jnp.sum(q * pos[:, 2] ** 2)) - jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12 ) - if neutral_flag: + if not neutral_flag: # kappa = pme_potential.pme_force.kappa pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC e_corr_non = pre_corr_non * Q_tot**2 @@ -211,20 +217,47 @@ def __init__( slab_flag: bool = False, constQ: bool = True, pbc_flag: bool = True, + part_const:bool = True, has_aux=False, ): self.has_aux = has_aux + self.part_const = part_const const_vals = np.array(const_vals) - if neutral_flag: - const_vals = const_vals - np.sum(const_vals) / len(const_vals) + #if neutral_flag: + # const_vals = const_vals - np.sum(const_vals) / len(const_vals) self.const_vals = jnp.array(const_vals) assert len(const_list) == len( const_vals ), "const_list and const_vals must have the same length" n_atoms = len(init_q) - self.const_list = padding_consts(const_list, n_atoms) + + const_mat = np.zeros((len(const_list), n_atoms)) + for ncl, cl in enumerate(const_list): + const_mat[ncl][cl] = 1 + self.const_mat = jnp.array(const_mat) + + if len(const_list) != 0: + self.const_list = padding_consts(const_list, n_atoms) + #if fix part charges + self.all_const_list = self.const_list[jnp.where(self.const_list < n_atoms)] + else: + self.const_list = np.array(const_list) + self.all_const_list = self.const_list + + all_fix_list = jnp.setdiff1d(jnp.array(range(n_atoms)),self.all_const_list) + fix_mat = np.zeros((len(all_fix_list),n_atoms)) + for i, j in enumerate(all_fix_list): + fix_mat[i][j] = 1 + self.all_fix_list = jnp.array(all_fix_list) + self.fix_mat = jnp.array(fix_mat) + self.init_q = jnp.array(init_q) self.init_lagmt = jnp.ones((len(const_list),)) + + self.init_energy = True #init charge by hession inversion method + self.icount = 0 + self.hessinv_stride = 1 + self.qupdate_stride = 1 self.damp_mod = damp_mod self.neutral_flag = neutral_flag @@ -234,6 +267,8 @@ def __init__( if constQ: e_constraint = E_constQ + elif not part_const: + e_constraint = E_noconst else: e_constraint = E_constP self.e_constraint = e_constraint @@ -318,11 +353,15 @@ def coul_energy(positions, box, pairs, q, mscales): self.coul_energy = coul_energy + def generate_get_energy(self): @jit_condition() - def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales): - e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals) - e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales) + def E_full(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales): + if self.part_const: + e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals) + else: + e1 = 0 + e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, buffer_scales, self.pbc_flag) e3 = self.e_site(chi, J, q) e4 = self.coul_energy(pos, box, pairs, q, mscales) if self.slab_flag: @@ -336,70 +375,195 @@ def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales): grad_E_full = grad(E_full, argnums=(0, 1)) @jit_condition() - def E_grads( - b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales - ): - n_const = len(self.const_vals) - q = b_value[:-n_const] - lagmt = b_value[-n_const:] - - g1, g2 = grad_E_full( - q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales - ) - g = jnp.concatenate((g1, g2)) - return g + def E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales): + h = jacfwd(jacrev(E_full, argnums=(0)))(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales) + return h - def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None): + @jit_condition() + def get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux=None): pos = positions - ds = ds_pairs(pos, box, pairs, self.pbc_flag) buffer_scales = pair_buffer_scales(pairs) n_const = len(self.init_lagmt) + b_vector = jnp.concatenate((-chi, self.const_vals)) #For E_site3 + if self.has_aux: - b_value = jnp.concatenate((aux["q"], aux["lagmt"])) + q = aux["q"][:len(pos)] + lagmt = aux["q"][len(pos):] else: - b_value = jnp.concatenate([self.init_q, self.init_lagmt]) - # if JAXOPT_OLD: - if True: - rf = jaxopt.ScipyRootFinding( - optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10 - ) + q = self.init_q + lagmt = self.init_lagmt + B = E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales) + + if self.part_const: + C = jnp.eye(len(q)) + A = C.at[self.all_const_list].set(B[self.all_const_list]) else: - rf = jaxopt.Broyden(fun=E_grads, tol=1e-10) - b_0, _ = rf.run( - b_value, + A = self.fix_mat + + b_vector = b_vector.at[self.all_fix_list].set(q[self.all_fix_list]) + + + m0 = jnp.concatenate((A,self.const_mat),axis=0) + n0 = jnp.concatenate((jnp.transpose(self.const_mat),jnp.zeros((n_const,n_const))),axis=0) + + M = jnp.concatenate((m0,n0),axis=1) + + q_0 = jnp.linalg.solve(M,b_vector) + q = q_0[:len(pos)] + lagmt = q_0[len(pos):] + energy = E_full( + q, + lagmt, chi, J, positions, box, pairs, eta, - ds, buffer_scales, mscales, ) - b_0 = jax.lax.stop_gradient(b_0) - q_0 = b_0[:-n_const] - lagmt_0 = b_0[-n_const:] + self.init_energy = False + # self.icount = self.icount + 1 + if self.has_aux: + aux["q"] = q_0 + aux["A"] = A + aux["m0"] = m0 + aux["n0"] = n0 + aux["b_vector"] = b_vector + aux["init_energy"] = self.init_energy + aux["icount"] = self.icount + return energy, aux + else: + return energy + + # @jit_condition() + def get_proj_grad(func, constraint_matrix, has_aux=False): + def value_and_proj_grad(*arg, **kwargs): + value, grad = value_and_grad(func, has_aux=has_aux)(*arg, **kwargs) + a = jnp.matmul(constraint_matrix, grad.reshape(-1, 1)) + b = jnp.sum(constraint_matrix * constraint_matrix, axis=1, keepdims=True) + delta_grad = jnp.matmul((a / b).T, constraint_matrix) + proj_grad = grad - delta_grad.reshape(-1) + return value, proj_grad + return value_and_proj_grad + + @jit_condition() + def get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux=None): + if self.init_energy: + if self.has_aux: + energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux) + return energy, aux + else: + energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux) + return energy + if not self.icount % self.hessinv_stride : + if self.has_aux: + energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux) + return energy, aux + else: + energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux) + return energy + + func = get_proj_grad(E_full,self.const_mat) + solver = jaxopt.LBFGS( + fun=func, + value_and_grad=True, + tol=1e-2, + ) + pos = positions + buffer_scales = pair_buffer_scales(pairs) + if self.has_aux: + q = aux["q"][:len(pos)] + lagmt = aux["q"][len(pos):] + else: + q = self.init_q + lagmt = self.init_lagmt + + res = solver.run( + q, + lagmt, + chi, + J, + positions, + box, + pairs, + eta, + buffer_scales, + mscales, + ) + q_opt = res.params energy = E_full( - q_0, - lagmt_0, + q_opt, + lagmt, chi, J, positions, box, pairs, eta, - ds, buffer_scales, mscales, ) if self.has_aux: - aux["q"] = q_0 - aux["lagmt"] = lagmt_0 + aux["q"] = aux['q'].at[:len(pos)].set(q_opt) return energy, aux else: return energy + # @jit_condition() + def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None): + if self.has_aux : + if "const_vals" in aux.keys(): + self.const_vals = aux["const_vals"] + if "hessinv_stride" in aux.keys(): + self.hessinv_stride = aux["hessinv_stride"] + if "qupdate_stride" in aux.keys(): + self.qupdate_stride = aux["qupdate_stride"] + if not self.icount % self.qupdate_stride : + if self.has_aux: + # aux["q"] = aux['q'].at[:len(pos)].set(q) + energy, aux = get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux) + self.icount = self.icount + 1 + aux["icount"] = self.icount + return energy, aux + else: + self.icount = self.icount + 1 + energy = get_step_energy(positions, box, pairs, mscales, eta, chi, J ) + return energy + else: + self.icount = self.icount + 1 + # print(self.icount) + pos = positions + buffer_scales = pair_buffer_scales(pairs) + if self.has_aux: + q = aux["q"][:len(pos)] + lagmt = aux["q"][len(pos):] + else: + q = self.init_q + lagmt = self.init_lagmt + energy = E_full( + q, + lagmt, + chi, + J, + positions, + box, + pairs, + eta, + buffer_scales, + mscales, + ) + + if self.has_aux: + aux = aux + # aux["q"] = aux['q'].at[:len(pos)].set(q) + aux["icount"] = self.icount + return energy, aux + else: + return energy + return get_energy + diff --git a/dmff/generators/qeq.py b/dmff/generators/qeq.py index d8e599ce2..5faaa1512 100644 --- a/dmff/generators/qeq.py +++ b/dmff/generators/qeq.py @@ -148,6 +148,7 @@ def createPotential( neutral_flag = kwargs.get("neutral", True) slab_flag = kwargs.get("slab", False) constQ = kwargs.get("constQ", True) + part_const = kwargs.get("part_const", True) # top info n_atoms = topdata.getNumAtoms() @@ -188,6 +189,7 @@ def createPotential( slab_flag=slab_flag, constQ=constQ, pbc_flag=(not isNoCut), + part_const=part_const, has_aux=has_aux, ) qeq_energy = qeq_force.generate_get_energy() @@ -206,13 +208,12 @@ def potential_fn( eta = params[self.name]["eta"][map_idx] chi = params[self.name]["chi"][map_idx] J = params[self.name]["J"][map_idx] - if has_aux: qeq_energy0, aux = qeq_energy(positions, box, pairs, mscales_coul, eta, chi, J, aux) # return pme_energy + qeq_energy0 return qeq_energy0, aux else: - qeq_energy0 = qeq_energy(positions, box, pairs, mscales_coul, eta, chi, J) + qeq_energy0 = qeq_energy(positions, box, pairs, mscales_coul, eta, chi, J ) # return pme_energy + qeq_energy0 return qeq_energy0 diff --git a/tests/data/qeq.xml b/tests/data/qeq.xml index a017ea93f..120312626 100644 --- a/tests/data/qeq.xml +++ b/tests/data/qeq.xml @@ -156,10 +156,10 @@ - - - - - + + + + + diff --git a/tests/data/qeq2.xml b/tests/data/qeq2.xml index e8e79bec1..d93e876d2 100644 --- a/tests/data/qeq2.xml +++ b/tests/data/qeq2.xml @@ -163,12 +163,12 @@ - - - - - - - + + + + + + + diff --git a/tests/test_admp/test_qeq.py b/tests/test_admp/test_qeq.py index 296801eba..54240ab95 100644 --- a/tests/test_admp/test_qeq.py +++ b/tests/test_admp/test_qeq.py @@ -10,6 +10,7 @@ def test_qeq_energy(): + rc = 0.8 xml = XMLIO() xml.loadXML("tests/data/qeq.xml") res = xml.parseResidues() @@ -28,19 +29,19 @@ def test_qeq_energy(): a.meta["charge"] = charges[na] a.meta["type"] = types[na] - nblist = NeighborList(box, 0.6, dmfftop.buildCovMat()) + nblist = NeighborList(box, rc, dmfftop.buildCovMat()) pairs = nblist.allocate(pos) - pot = hamilt.createPotential(dmfftop, nonbondedCutoff=0.6*unit.nanometer, nonbondedMethod=app.PME, - ethresh=1e-3, neutral=True, slab=True, constQ=True + pot = hamilt.createPotential(dmfftop, nonbondedCutoff=rc*unit.nanometer, nonbondedMethod=app.PME, + ethresh=1e-3, neutral=True, slab=True, constQ=True,pbc_flag=True,part_const=True ) efunc = pot.getPotentialFunc() energy = efunc(pos, box, pairs, hamilt.paramset.parameters) - np.testing.assert_almost_equal(energy, -37.84692763, decimal=3) + np.testing.assert_almost_equal(energy, -37.84692763, decimal=2) def test_qeq_energy_2res(): - rc = 0.6 + rc = 0.8 xml = XMLIO() xml.loadXML("tests/data/qeq2.xml") res = xml.parseResidues() @@ -80,13 +81,17 @@ def test_qeq_energy_2res(): has_aux=True ) efunc = pot.getPotentialFunc() + + n_template = len(const_list) + q = jnp.array([i.meta['charge'] for i in atoms]) + lagmt = jnp.ones(n_template) aux = { "q": jnp.array(charges), "lagmt": jnp.array([1.0, 1.0]) } - energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux) - print(aux) - np.testing.assert_almost_equal(energy, 4817.295171, decimal=2) + energy, aux0 = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux) + print(aux['q']) + np.testing.assert_almost_equal(energy, 4817.295171, decimal=0) grad = jax.grad(efunc, argnums=0, has_aux=True) gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux) @@ -94,7 +99,7 @@ def test_qeq_energy_2res(): def _test_qeq_energy_2res_jit(): - rc = 0.6 + rc = 0.8 xml = XMLIO() xml.loadXML("tests/data/qeq2.xml") res = xml.parseResidues() @@ -135,15 +140,17 @@ def _test_qeq_energy_2res_jit(): ) efunc = jax.jit(pot.getPotentialFunc()) grad = jax.jit(jax.grad(efunc, argnums=0, has_aux=True)) + q = jnp.array(charges) + lagmt = jnp.array([1.0, 1.0]) aux = { - "q": jnp.array(charges), + "q": jnp.array(charges), "lagmt": jnp.array([1.0, 1.0]) } print("Start computing energy and force.") energy, aux = efunc(pos, box, pairs, hamilt.paramset.parameters, aux=aux) print(aux) - np.testing.assert_almost_equal(energy, 4817.295171, decimal=2) + np.testing.assert_almost_equal(energy, 4817.295171, decimal=0) grad = jax.grad(efunc, argnums=0, has_aux=True) gradient, aux = grad(pos, box, pairs, hamilt.paramset.parameters, aux=aux) - print(gradient) \ No newline at end of file + print(gradient)