Skip to content

Commit

Permalink
update qeq, matrix invertion method and projected gradient method wer…
Browse files Browse the repository at this point in the history
…âre added, units of some paramters were unified
  • Loading branch information
gust-07 committed Sep 5, 2024
1 parent cdb4ac6 commit be90571
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 77 deletions.
266 changes: 215 additions & 51 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import jax.numpy as jnp
from ..common.constants import DIELECTRIC
from jax import grad, vmap
from jax import grad, value_and_grad, vmap, jacfwd, jacrev
from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
from typing import Tuple, List
from ..settings import PRECISION
Expand Down Expand Up @@ -69,15 +69,19 @@ def padding_consts(const_list, max_idx):

@jit_condition()
def E_constQ(q, lagmt, const_list, const_vals):
constraint = (group_sum(q, const_list) - const_vals) * lagmt
return jnp.sum(constraint)
# constraint = (group_sum(q, const_list) - const_vals) * lagmt
# return jnp.sum(constraint)
return 0.0


@jit_condition()
def E_constP(q, lagmt, const_list, const_vals):
constraint = group_sum(q, const_list) * const_vals
return jnp.sum(constraint)

@jit_condition()
def E_noconst(q, lagmt, const_list, const_vals):
return 0.0

@vmap
@jit_condition()
Expand All @@ -87,13 +91,14 @@ def mask_to_zero(v, mask):
)


@jit_condition()
def E_sr(pos, box, pairs, q, eta, ds, buffer_scales):
@jit_condition(static_argnums=[6])
def E_sr(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
return 0.0


@jit_condition()
def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
@jit_condition(static_argnums=[6])
def E_sr2(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
ds = ds_pairs(pos, box, pairs, pbc_flag)
etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2))
pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC
pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
Expand All @@ -104,8 +109,9 @@ def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
return e_sr


@jit_condition()
def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales):
@jit_condition(static_argnums=[6])
def E_sr3(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
ds = ds_pairs(pos, box, pairs, pbc_flag)
etasqrt = jnp.sqrt(
eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2 + 1e-64
) # add eta to avoid division by zero
Expand All @@ -126,13 +132,13 @@ def E_site(chi, J, q):

@jit_condition()
def E_site2(chi, J, q):
ene = (chi * q + 0.5 * J * q**2) * 96.4869
ene = (chi * q + 0.5 * J * q**2) * 96.4869 #ev to kj/mol
return jnp.sum(ene)


@jit_condition()
def E_site3(chi, J, q):
ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi
ene = chi * q + J* q**2 # kj/mol
return jnp.sum(ene)


Expand All @@ -149,7 +155,7 @@ def E_corr(pos, box, pairs, q, kappa, neutral_flag=True):
- Q_tot * (jnp.sum(q * pos[:, 2] ** 2))
- jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12
)
if neutral_flag:
if not neutral_flag:
# kappa = pme_potential.pme_force.kappa
pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC
e_corr_non = pre_corr_non * Q_tot**2
Expand Down Expand Up @@ -211,20 +217,47 @@ def __init__(
slab_flag: bool = False,
constQ: bool = True,
pbc_flag: bool = True,
part_const:bool = True,
has_aux=False,
):
self.has_aux = has_aux
self.part_const = part_const
const_vals = np.array(const_vals)
if neutral_flag:
const_vals = const_vals - np.sum(const_vals) / len(const_vals)
#if neutral_flag:
# const_vals = const_vals - np.sum(const_vals) / len(const_vals)
self.const_vals = jnp.array(const_vals)
assert len(const_list) == len(
const_vals
), "const_list and const_vals must have the same length"
n_atoms = len(init_q)
self.const_list = padding_consts(const_list, n_atoms)

const_mat = np.zeros((len(const_list), n_atoms))
for ncl, cl in enumerate(const_list):
const_mat[ncl][cl] = 1
self.const_mat = jnp.array(const_mat)

if len(const_list) != 0:
self.const_list = padding_consts(const_list, n_atoms)
#if fix part charges
self.all_const_list = self.const_list[jnp.where(self.const_list < n_atoms)]
else:
self.const_list = np.array(const_list)
self.all_const_list = self.const_list

all_fix_list = jnp.setdiff1d(jnp.array(range(n_atoms)),self.all_const_list)
fix_mat = np.zeros((len(all_fix_list),n_atoms))
for i, j in enumerate(all_fix_list):
fix_mat[i][j] = 1
self.all_fix_list = jnp.array(all_fix_list)
self.fix_mat = jnp.array(fix_mat)

self.init_q = jnp.array(init_q)
self.init_lagmt = jnp.ones((len(const_list),))

self.init_energy = True #init charge by hession inversion method
self.icount = 0
self.hessinv_stride = 1
self.qupdate_stride = 1

self.damp_mod = damp_mod
self.neutral_flag = neutral_flag
Expand All @@ -234,6 +267,8 @@ def __init__(

if constQ:
e_constraint = E_constQ
elif not part_const:
e_constraint = E_noconst
else:
e_constraint = E_constP
self.e_constraint = e_constraint
Expand Down Expand Up @@ -318,11 +353,15 @@ def coul_energy(positions, box, pairs, q, mscales):

self.coul_energy = coul_energy


def generate_get_energy(self):
@jit_condition()
def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales)
def E_full(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales):
if self.part_const:
e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
else:
e1 = 0
e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, buffer_scales, self.pbc_flag)
e3 = self.e_site(chi, J, q)
e4 = self.coul_energy(pos, box, pairs, q, mscales)
if self.slab_flag:
Expand All @@ -336,70 +375,195 @@ def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
grad_E_full = grad(E_full, argnums=(0, 1))

@jit_condition()
def E_grads(
b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
n_const = len(self.const_vals)
q = b_value[:-n_const]
lagmt = b_value[-n_const:]

g1, g2 = grad_E_full(
q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
)
g = jnp.concatenate((g1, g2))
return g
def E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales):
h = jacfwd(jacrev(E_full, argnums=(0)))(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales)
return h

def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
@jit_condition()
def get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
pos = positions
ds = ds_pairs(pos, box, pairs, self.pbc_flag)
buffer_scales = pair_buffer_scales(pairs)

n_const = len(self.init_lagmt)
b_vector = jnp.concatenate((-chi, self.const_vals)) #For E_site3

if self.has_aux:
b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
q = aux["q"][:len(pos)]
lagmt = aux["q"][len(pos):]
else:
b_value = jnp.concatenate([self.init_q, self.init_lagmt])
# if JAXOPT_OLD:
if True:
rf = jaxopt.ScipyRootFinding(
optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
)
q = self.init_q
lagmt = self.init_lagmt
B = E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales)

if self.part_const:
C = jnp.eye(len(q))
A = C.at[self.all_const_list].set(B[self.all_const_list])
else:
rf = jaxopt.Broyden(fun=E_grads, tol=1e-10)
b_0, _ = rf.run(
b_value,
A = self.fix_mat

b_vector = b_vector.at[self.all_fix_list].set(q[self.all_fix_list])


m0 = jnp.concatenate((A,self.const_mat),axis=0)
n0 = jnp.concatenate((jnp.transpose(self.const_mat),jnp.zeros((n_const,n_const))),axis=0)

M = jnp.concatenate((m0,n0),axis=1)

q_0 = jnp.linalg.solve(M,b_vector)
q = q_0[:len(pos)]
lagmt = q_0[len(pos):]
energy = E_full(
q,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
b_0 = jax.lax.stop_gradient(b_0)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]
self.init_energy = False
# self.icount = self.icount + 1
if self.has_aux:
aux["q"] = q_0
aux["A"] = A
aux["m0"] = m0
aux["n0"] = n0
aux["b_vector"] = b_vector
aux["init_energy"] = self.init_energy
aux["icount"] = self.icount
return energy, aux
else:
return energy


# @jit_condition()
def get_proj_grad(func, constraint_matrix, has_aux=False):
def value_and_proj_grad(*arg, **kwargs):
value, grad = value_and_grad(func, has_aux=has_aux)(*arg, **kwargs)
a = jnp.matmul(constraint_matrix, grad.reshape(-1, 1))
b = jnp.sum(constraint_matrix * constraint_matrix, axis=1, keepdims=True)
delta_grad = jnp.matmul((a / b).T, constraint_matrix)
proj_grad = grad - delta_grad.reshape(-1)
return value, proj_grad
return value_and_proj_grad

@jit_condition()
def get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
if self.init_energy:
if self.has_aux:
energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy, aux
else:
energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy
if not self.icount % self.hessinv_stride :
if self.has_aux:
energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy, aux
else:
energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy

func = get_proj_grad(E_full,self.const_mat)
solver = jaxopt.LBFGS(
fun=func,
value_and_grad=True,
tol=1e-2,
)
pos = positions
buffer_scales = pair_buffer_scales(pairs)
if self.has_aux:
q = aux["q"][:len(pos)]
lagmt = aux["q"][len(pos):]
else:
q = self.init_q
lagmt = self.init_lagmt

res = solver.run(
q,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
buffer_scales,
mscales,
)
q_opt = res.params
energy = E_full(
q_0,
lagmt_0,
q_opt,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
if self.has_aux:
aux["q"] = q_0
aux["lagmt"] = lagmt_0
aux["q"] = aux['q'].at[:len(pos)].set(q_opt)
return energy, aux
else:
return energy
# @jit_condition()
def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
if self.has_aux :
if "const_vals" in aux.keys():
self.const_vals = aux["const_vals"]
if "hessinv_stride" in aux.keys():
self.hessinv_stride = aux["hessinv_stride"]
if "qupdate_stride" in aux.keys():
self.qupdate_stride = aux["qupdate_stride"]
if not self.icount % self.qupdate_stride :
if self.has_aux:
# aux["q"] = aux['q'].at[:len(pos)].set(q)
energy, aux = get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux)
self.icount = self.icount + 1
aux["icount"] = self.icount
return energy, aux
else:
self.icount = self.icount + 1
energy = get_step_energy(positions, box, pairs, mscales, eta, chi, J )
return energy

else:
self.icount = self.icount + 1
# print(self.icount)
pos = positions
buffer_scales = pair_buffer_scales(pairs)
if self.has_aux:
q = aux["q"][:len(pos)]
lagmt = aux["q"][len(pos):]
else:
q = self.init_q
lagmt = self.init_lagmt
energy = E_full(
q,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
buffer_scales,
mscales,
)

if self.has_aux:
aux = aux
# aux["q"] = aux['q'].at[:len(pos)].set(q)
aux["icount"] = self.icount
return energy, aux
else:
return energy

return get_energy

Loading

0 comments on commit be90571

Please sign in to comment.