Skip to content

Commit

Permalink
AD for gauge actions
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jul 11, 2024
1 parent 37880c1 commit 3bace57
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/gpt/ad/reverse/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ def functional(self, *arguments):
def get_grid(self):
return self.value.grid

def get_otype(self):
return self.value.otype

def get_real(self):
def getter(y):
return y.real
Expand All @@ -368,6 +371,7 @@ def setter(y, z):
return self.project(getter, setter)

grid = property(get_grid)
otype = property(get_otype)
real = property(get_real)


Expand Down
1 change: 1 addition & 0 deletions lib/gpt/qcd/gauge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
field_strength,
differentiable_topology,
differentiable_energy_density,
differentiable_P_and_R,
)
from gpt.qcd.gauge.topology import topological_charge_5LI
from gpt.qcd.gauge.staples import staple
Expand Down
22 changes: 22 additions & 0 deletions lib/gpt/qcd/gauge/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,25 @@ def differentiable_energy_density(aU):
return (-1.0 / grid.gsites) * g.sum(g.trace(res))


def differentiable_P_and_R(aU):
Nd = len(aU)
grid = aU[0].grid
ndim = aU[0].otype.shape[0]
res_P = None
res_R = None
for mu in range(Nd):
for nu in range(Nd):
if mu == nu:
continue

staple_up, staple_down = staples(aU, mu, nu)

P = g.sum(g.trace(aU[mu] * staple_up))
R = g.sum(g.trace(g.adj(staple_down) * staple_up))

res_P = P if res_P is None else P + res_P
res_R = R if res_R is None else R + res_R

res_P = (res_P + g.adj(res_P)) * (0.5 / (Nd - 1) / Nd / ndim / grid.gsites)
res_R = (res_R + g.adj(res_R)) * (0.5 / (Nd - 1) / Nd / ndim / grid.gsites)
return res_P, res_R
30 changes: 30 additions & 0 deletions tests/ad/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,35 @@ def real(x):
nid += 1


# improved action test
U = g.qcd.gauge.random(grid, rng)
U_2 = [rad.node(g.copy(u)) for u in U]
P, R = g.qcd.gauge.differentiable_P_and_R(U_2)
a1 = g.qcd.gauge.action.iwasaki(1.0)
c1 = -0.331
c0 = 1.0 - 8.0 * c1
Nd = len(U)
ndim = U[0].otype.shape[0]
vol = grid.gsites

A = vol * (c0 * (1.0 - P) * (Nd - 1) * Nd / 2.0 + c1 * (1.0 - R) * (Nd - 1) * Nd)

a1p = A.functional(*U_2)
eps = abs(a1p(U) / a1(U) - 1)
g.message("Iwasaki test", eps)
assert eps < 1e-10

t0 = g.time()
grad = a1.gradient(U, U)
t1 = g.time()
gradp = a1p.gradient(U, U)
t2 = g.time()
g.message(f"Force time {t1 - t0}, AD force time {t2 - t1}")
for mu in range(4):
eps2 = g.norm2(grad[mu] - gradp[mu]) / g.norm2(grad[mu])
g.message("Force test", mu, eps2)
assert eps2 < 1e-20

#####################################
# forward AD tests
#####################################
Expand Down Expand Up @@ -375,6 +404,7 @@ def plaquette(U):
g.message(f"Numerical action gradient [{mu}] derivative test: {err}")
assert err < 1e-5


# test simple combination of forward and reverse
a = g.ad.forward.make(On, 1.3333 + 3.21j, dbeta, 2.1 + 0.7j)
b = g.ad.forward.make(On, 0.9 + 0.756j, dbeta, 1.3j + 0.21)
Expand Down

0 comments on commit 3bace57

Please sign in to comment.