From 404e2a29d155cde1b5f9e046254f3bd38a2a8d9e Mon Sep 17 00:00:00 2001
From: Christoph Lehner <christoph@lhnr.de>
Date: Wed, 3 Jul 2024 15:35:15 +0200
Subject: [PATCH] more speed: stout jacobian uses one cached stencil call instead of per-pair cshifts

---
 lib/gpt/qcd/gauge/smear/stout.py | 32 ++++++++++++++++++++++++++++----
 1 file changed, 28 insertions(+), 4 deletions(-)

diff --git a/lib/gpt/qcd/gauge/smear/stout.py b/lib/gpt/qcd/gauge/smear/stout.py
index ecb36a82..f700bb3c 100644
--- a/lib/gpt/qcd/gauge/smear/stout.py
+++ b/lib/gpt/qcd/gauge/smear/stout.py
@@ -53,6 +53,7 @@ class stout(diffeomorphism):
     def __init__(self, params):
         self.params = params
         self.verbose = g.default.is_verbose("stout_performance")
+        self.stencil = None
 
     # apply the smearing
     def __call__(self, fields):
@@ -108,6 +109,29 @@ def jacobian(self, fields, fields_prime, src):
             tt("expr")
             dst[mu] @= Sigma_prime[mu] * exp_iQ[mu] + g.adj(C[mu]) * 1j * Lambda[mu]
 
+        # fill all shifted U and Lambda fields with one stencil call instead of per-pair cshifts
+        U_shifted = [g.lattice(U[0]) for mu in range(nd * nd)]
+        Lambda_shifted = [g.lattice(Lambda[0]) for mu in range(nd * nd)]
+        if self.stencil is None:
+            directions = [tuple([1 if mu == nu else 0 for mu in range(nd)]) for nu in range(nd)]
+            self.stencil = g.stencil.matrix(
+                U[0],
+                directions,
+                [
+                    (nd * mu + nu, -1, 1.0, [(mu + nd * nd * 2, nu, 0)])
+                    for mu in range(nd)
+                    for nu in range(nd)
+                    if mu != nu
+                ]
+                + [
+                    (nd * mu + nu + nd * nd, -1, 1.0, [(mu + nd * nd * 2 + nd, nu, 0)])
+                    for mu in range(nd)
+                    for nu in range(nd)
+                    if mu != nu
+                ],
+            )
+        self.stencil(*U_shifted, *Lambda_shifted, *U, *Lambda)
+
         for mu in range(nd):
             for nu in range(nd):
                 if mu == nu:
@@ -118,10 +142,10 @@ def jacobian(self, fields, fields_prime, src):
 
                 if abs(rho_nu_mu) != 0.0 or abs(rho_mu_nu) != 0.0:
                     tt("cshift")
-                    U_nu_x_plus_mu = g.cshift(U[nu], mu, 1)
-                    U_mu_x_plus_nu = g.cshift(U[mu], nu, 1)
-                    Lambda_nu_x_plus_mu = g.cshift(Lambda[nu], mu, 1)
-                    Lambda_mu_x_plus_nu = g.cshift(Lambda[mu], nu, 1)
+                    U_nu_x_plus_mu = U_shifted[nd * nu + mu]
+                    U_mu_x_plus_nu = U_shifted[nd * mu + nu]
+                    Lambda_nu_x_plus_mu = Lambda_shifted[nd * nu + mu]
+                    Lambda_mu_x_plus_nu = Lambda_shifted[nd * mu + nu]
 
                     tt("expr")