From 404e2a29d155cde1b5f9e046254f3bd38a2a8d9e Mon Sep 17 00:00:00 2001 From: Christoph Lehner <christoph@lhnr.de> Date: Wed, 3 Jul 2024 15:35:15 +0200 Subject: [PATCH] more speed --- lib/gpt/qcd/gauge/smear/stout.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/lib/gpt/qcd/gauge/smear/stout.py b/lib/gpt/qcd/gauge/smear/stout.py index ecb36a82..f700bb3c 100644 --- a/lib/gpt/qcd/gauge/smear/stout.py +++ b/lib/gpt/qcd/gauge/smear/stout.py @@ -53,6 +53,7 @@ class stout(diffeomorphism): def __init__(self, params): self.params = params self.verbose = g.default.is_verbose("stout_performance") + self.stencil = None # apply the smearing def __call__(self, fields): @@ -108,6 +109,29 @@ def jacobian(self, fields, fields_prime, src): tt("expr") dst[mu] @= Sigma_prime[mu] * exp_iQ[mu] + g.adj(C[mu]) * 1j * Lambda[mu] + # create all shifted U and lambda fields at once + U_shifted = [g.lattice(U[0]) for mu in range(nd * nd)] + Lambda_shifted = [g.lattice(Lambda[0]) for mu in range(nd * nd)] + if self.stencil is None: + directions = [tuple([1 if mu == nu else 0 for mu in range(nd)]) for nu in range(nd)] + self.stencil = g.stencil.matrix( + U[0], + directions, + [ + (nd * mu + nu, -1, 1.0, [(mu + nd * nd * 2, nu, 0)]) + for mu in range(nd) + for nu in range(nd) + if mu != nu + ] + + [ + (nd * mu + nu + nd * nd, -1, 1.0, [(mu + nd * nd * 2 + nd, nu, 0)]) + for mu in range(nd) + for nu in range(nd) + if mu != nu + ], + ) + self.stencil(*U_shifted, *Lambda_shifted, *U, *Lambda) + for mu in range(nd): for nu in range(nd): if mu == nu: @@ -118,10 +142,10 @@ def jacobian(self, fields, fields_prime, src): if abs(rho_nu_mu) != 0.0 or abs(rho_mu_nu) != 0.0: tt("cshift") - U_nu_x_plus_mu = g.cshift(U[nu], mu, 1) - U_mu_x_plus_nu = g.cshift(U[mu], nu, 1) - Lambda_nu_x_plus_mu = g.cshift(Lambda[nu], mu, 1) - Lambda_mu_x_plus_nu = g.cshift(Lambda[mu], nu, 1) + U_nu_x_plus_mu = U_shifted[nd * nu + mu] + U_mu_x_plus_nu = U_shifted[nd * mu + nu] + Lambda_nu_x_plus_mu = Lambda_shifted[nd * nu + mu] + Lambda_mu_x_plus_nu = Lambda_shifted[nd * mu + nu] tt("expr")