diff --git a/include/array.h b/include/array.h index 3005087c85..a60d78db23 100644 --- a/include/array.h +++ b/include/array.h @@ -24,6 +24,12 @@ namespace quda array &operator=(const array &) = default; array &operator=(array &&) = default; + + template constexpr array &operator=(const array &other) + { + for (int i = 0; i < n; i++) data[i] = other[i]; + return *this; + } }; template std::ostream &operator<<(std::ostream &output, const array &a) diff --git a/include/blas_helper.cuh b/include/blas_helper.cuh index 80e974c0c9..100f1387cf 100644 --- a/include/blas_helper.cuh +++ b/include/blas_helper.cuh @@ -23,48 +23,6 @@ namespace quda static constexpr bool V = V_; }; - __host__ __device__ inline double set(double &x) { return x; } - __host__ __device__ inline double2 set(double2 &x) { return x; } - __host__ __device__ inline double3 set(double3 &x) { return x; } - __host__ __device__ inline double4 set(double4 &x) { return x; } - __host__ __device__ inline void sum(double &a, double &b) { a += b; } - __host__ __device__ inline void sum(double2 &a, double2 &b) - { - a.x += b.x; - a.y += b.y; - } - __host__ __device__ inline void sum(double3 &a, double3 &b) - { - a.x += b.x; - a.y += b.y; - a.z += b.z; - } - __host__ __device__ inline void sum(double4 &a, double4 &b) - { - a.x += b.x; - a.y += b.y; - a.z += b.z; - a.w += b.w; - } - -#ifdef QUAD_SUM - __host__ __device__ inline double set(doubledouble &a) { return a.head(); } - __host__ __device__ inline double2 set(doubledouble2 &a) { return make_double2(a.x.head(), a.y.head()); } - __host__ __device__ inline double3 set(doubledouble3 &a) { return make_double3(a.x.head(), a.y.head(), a.z.head()); } - __host__ __device__ inline void sum(double &a, doubledouble &b) { a += b.head(); } - __host__ __device__ inline void sum(double2 &a, doubledouble2 &b) - { - a.x += b.x.head(); - a.y += b.y.head(); - } - __host__ __device__ inline void sum(double3 &a, doubledouble3 &b) - { - a.x += b.x.head(); - a.y += b.y.head(); - a.z += b.z.head(); - } -#endif - // Vector types used for AoS load-store on CPU template <> struct VectorType { using type = array; diff --git a/include/blas_quda.h b/include/blas_quda.h index 8df40df452..8bb711af38 100644 --- a/include/blas_quda.h +++ b/include/blas_quda.h @@ -49,14 +49,14 @@ namespace quda { @param[in] x input vector @param[out] y output vector */ - void axy(double a, const ColorSpinorField &x, ColorSpinorField &y); + void axy(real_t a, const ColorSpinorField &x, ColorSpinorField &y); /** @brief Apply the rescale operation x = a * x @param[in] a scalar multiplier @param[in] x input vector */ - inline void ax(double a, ColorSpinorField &x) { axy(a, x, x); } + inline void ax(real_t a, ColorSpinorField &x) { axy(a, x, x); } /** @brief Apply the operation z = a * x + b * y @@ -66,7 +66,7 @@ namespace quda { @param[in] y input vector @param[out] z output vector */ - void axpbyz(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y, ColorSpinorField &z); + void axpbyz(real_t a, const ColorSpinorField &x, real_t b, const ColorSpinorField &y, ColorSpinorField &z); /** @brief Apply the operation y += x @@ -88,7 +88,7 @@ namespace quda { @param[in] x input vector @param[in,out] y update vector */ - inline void axpy(double a, const ColorSpinorField &x, ColorSpinorField &y) { axpbyz(a, x, 1.0, y, y); } + inline void axpy(real_t a, const ColorSpinorField &x, ColorSpinorField &y) { axpbyz(a, x, 1.0, y, y); } /** @brief Apply the operation y = a * x + b * y @@ -97,7 +97,7 @@ namespace quda { @param[in] b scalar 
multiplier @param[in,out] y update vector */ - inline void axpby(double a, const ColorSpinorField &x, double b, ColorSpinorField &y) { axpbyz(a, x, b, y, y); } + inline void axpby(real_t a, const ColorSpinorField &x, real_t b, ColorSpinorField &y) { axpbyz(a, x, b, y, y); } /** @brief Apply the operation y = x + a * y @@ -105,7 +105,7 @@ namespace quda { @param[in] a scalar multiplier @param[in,out] y update vector */ - inline void xpay(const ColorSpinorField &x, double a, ColorSpinorField &y) { axpbyz(1.0, x, a, y, y); } + inline void xpay(const ColorSpinorField &x, real_t a, ColorSpinorField &y) { axpbyz(1.0, x, a, y, y); } /** @brief Apply the operation z = x + a * y @@ -114,7 +114,7 @@ namespace quda { @param[in] y input vector @param[out] z output vector */ - inline void xpayz(const ColorSpinorField &x, double a, const ColorSpinorField &y, ColorSpinorField &z) { axpbyz(1.0, x, a, y, z); } + inline void xpayz(const ColorSpinorField &x, real_t a, const ColorSpinorField &y, ColorSpinorField &z) { axpbyz(1.0, x, a, y, z); } /** @brief Apply the operation y = a * x + y, x = z + b * x @@ -124,7 +124,7 @@ namespace quda { @param[in] z input vector @param[in] b scalar multiplier */ - void axpyZpbx(double a, ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z, double b); + void axpyZpbx(real_t a, ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z, real_t b); /** @brief Apply the operation y = a * x + y, x = b * z + c * x @@ -135,7 +135,7 @@ namespace quda { @param[in] z input vector @param[in] c scalar multiplier */ - void axpyBzpcx(double a, ColorSpinorField& x, ColorSpinorField& y, double b, const ColorSpinorField& z, double c); + void axpyBzpcx(real_t a, ColorSpinorField& x, ColorSpinorField& y, real_t b, const ColorSpinorField& z, real_t c); /** @brief Apply the operation w = a * x + b * y + c * z @@ -147,8 +147,8 @@ namespace quda { @param[in] z input vector @param[out] w output vector */ - void axpbypczw(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y, - double c, const ColorSpinorField &z, ColorSpinorField &w); + void axpbypczw(real_t a, const ColorSpinorField &x, real_t b, const ColorSpinorField &y, + real_t c, const ColorSpinorField &z, ColorSpinorField &w); /** @brief Apply the operation y = a * x + b * y @@ -158,7 +158,7 @@ namespace quda { @param[in,out] y update vector */ - void caxpby(const Complex &a, const ColorSpinorField &x, const Complex &b, ColorSpinorField &y); + void caxpby(const complex_t &a, const ColorSpinorField &x, const complex_t &b, ColorSpinorField &y); /** @brief Apply the operation y += a * x @param[in] a scalar multiplier @param[in] x input vector @param[in,out] y update vector */ - void caxpy(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y); + void caxpy(const complex_t &a, const ColorSpinorField &x, ColorSpinorField &y); /** @brief Apply the operation z = x + a * y + b * z @@ -176,8 +176,8 @@ namespace quda { @param[in] b scalar multiplier @param[in,out] z update vector */ - void cxpaypbz(const ColorSpinorField &x, const Complex &a, const ColorSpinorField &y, - const Complex &b, ColorSpinorField &z); + void cxpaypbz(const ColorSpinorField &x, const complex_t &a, const ColorSpinorField &y, + const complex_t &b, ColorSpinorField &z); /** @brief Apply the operation z += a * x + b * y, y -= b * w @@ -188,7 +188,7 @@ namespace quda { @param[in] b scalar multiplier @param[in,out] z update vector @param[in] w input vector */ - void caxpbypzYmbw(const Complex &a, const ColorSpinorField &x, const
Complex &b, ColorSpinorField &y, + void caxpbypzYmbw(const complex_t &a, const ColorSpinorField &x, const complex_t &b, ColorSpinorField &y, ColorSpinorField &z, const ColorSpinorField &w); /** @@ -199,8 +199,8 @@ namespace quda { @param[in] b scalar multiplier @param[in] z input vector */ - void caxpyBzpx(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, - const Complex &b, const ColorSpinorField &z); + void caxpyBzpx(const complex_t &a, ColorSpinorField &x, ColorSpinorField &y, + const complex_t &b, const ColorSpinorField &z); /** @brief Apply the operation y += a * x, z += b * x @@ -210,8 +210,8 @@ namespace quda { @param[in] b scalar multiplier @param[in,out] z update vector */ - void caxpyBxpz(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y, - const Complex &b, ColorSpinorField &z); + void caxpyBxpz(const complex_t &a, const ColorSpinorField &x, ColorSpinorField &y, + const complex_t &b, ColorSpinorField &z); /** @brief Apply the operation y += a * b * x, x = a * x @@ -220,7 +220,7 @@ namespace quda { @param[in,out] x update vector @param[in,out] y update vector */ - void cabxpyAx(double a, const Complex &b, ColorSpinorField &x, ColorSpinorField &y); + void cabxpyAx(real_t a, const complex_t &b, ColorSpinorField &x, ColorSpinorField &y); /** @brief Apply the operation y += a * x, x -= a * z @@ -229,7 +229,7 @@ namespace quda { @param[in,out] y update vector @param[in] z input vector */ - void caxpyXmaz(const Complex &a, ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z); + void caxpyXmaz(const complex_t &a, ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z); /** @brief Apply the operation y += a * x, x = x - a * z. Special @@ -240,7 +240,7 @@ namespace quda { @param[in,out] y update vector @param[in] z input vector */ - void caxpyXmazMR(const double &a, ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z); + void caxpyXmazMR(const real_t &a, ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z); /** @brief Apply the operation y += a * w, z -= a * x, w = z + b * w @@ -251,7 +251,7 @@ namespace quda { @param[in,out] z update vector @param[in,out] w update vector */ - void tripleCGUpdate(double a, double b, const ColorSpinorField &x, + void tripleCGUpdate(real_t a, real_t b, const ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w); // reduction kernels - defined in reduce_quda.cu @@ -260,7 +260,7 @@ namespace quda { @brief Compute the maximum absolute real element of a field @param[in] a The field we are reducing */ - double max(const ColorSpinorField &x); + real_t max(const ColorSpinorField &x); /** @brief Compute the maximum real-valued deviation between two @@ -268,19 +268,19 @@ namespace quda { @param[in] x The field we want to compare @param[in] y The reference field to which we are comparing against */ - array max_deviation(const ColorSpinorField &x, const ColorSpinorField &y); + array max_deviation(const ColorSpinorField &x, const ColorSpinorField &y); /** @brief Compute the L1 norm of a field @param[in] x The field we are reducing */ - double norm1(const ColorSpinorField &x); + real_t norm1(const ColorSpinorField &x); /** @brief Compute the L2 norm (||x||^2) of a field @param[in] x The field we are reducing */ - double norm2(const ColorSpinorField &x); + real_t norm2(const ColorSpinorField &x); /** @brief Compute y += a * x and then (x, y) @@ -288,14 +288,14 @@ namespace quda { @param[in] x input vector @param[in,out] y update vector */ - double 
axpyReDot(double a, const ColorSpinorField &x, ColorSpinorField &y); + real_t axpyReDot(real_t a, const ColorSpinorField &x, ColorSpinorField &y); /** @brief Compute the real-valued inner product (x, y) @param[in] x input vector @param[in] y input vector */ - double reDotProduct(const ColorSpinorField &x, const ColorSpinorField &y); + real_t reDotProduct(const ColorSpinorField &x, const ColorSpinorField &y); /** @brief Compute z = a * x + b * y and then ||z||^2 @@ -305,7 +305,7 @@ namespace quda { @param[in] y input vector @param[in,out] z update vector */ - double axpbyzNorm(double a, const ColorSpinorField &x, double b, const ColorSpinorField &y, ColorSpinorField &z); + real_t axpbyzNorm(real_t a, const ColorSpinorField &x, real_t b, const ColorSpinorField &y, ColorSpinorField &z); /** @brief Compute y += a * x and then ||y||^2 @@ -313,38 +313,38 @@ namespace quda { @param[in] x input vector @param[in,out] y update vector */ - inline double axpyNorm(double a, const ColorSpinorField &x, ColorSpinorField &y) { return axpbyzNorm(a, x, 1.0, y, y); } + inline real_t axpyNorm(real_t a, const ColorSpinorField &x, ColorSpinorField &y) { return axpbyzNorm(a, x, 1.0, y, y); } /** @brief Compute y -= x and then ||y||^2 @param[in] x input vector @param[in,out] y update vector */ - inline double xmyNorm(const ColorSpinorField &x, ColorSpinorField &y) { return axpbyzNorm(1.0, x, -1.0, y, y); } + inline real_t xmyNorm(const ColorSpinorField &x, ColorSpinorField &y) { return axpbyzNorm(1.0, x, -1.0, y, y); } /** @brief Compute the complex-valued inner product (x, y) @param[in] x input vector @param[in] y input vector */ - Complex cDotProduct(const ColorSpinorField &x, const ColorSpinorField &y); + complex_t cDotProduct(const ColorSpinorField &x, const ColorSpinorField &y); /** @brief Return complex-valued inner product (x,y), ||x||^2 and ||y||^2 @param[in] x input vector @param[in] y input vector */ - double4 cDotProductNormAB(const ColorSpinorField &x, const ColorSpinorField &y); + array cDotProductNormAB(const ColorSpinorField &x, const ColorSpinorField &y); /** @brief Return complex-valued inner product (x,y) and ||x||^2 @param[in] x input vector @param[in] y input vector */ - inline double3 cDotProductNormA(const ColorSpinorField &x, const ColorSpinorField &y) + inline array cDotProductNormA(const ColorSpinorField &x, const ColorSpinorField &y) { auto a4 = cDotProductNormAB(x, y); - return make_double3(a4.x, a4.y, a4.z); + return {a4[0], a4[1], a4[2]}; } /** @@ -352,10 +352,10 @@ namespace quda { @param[in] x input vector @param[in] y input vector */ - inline double3 cDotProductNormB(const ColorSpinorField &x, const ColorSpinorField &y) + inline array cDotProductNormB(const ColorSpinorField &x, const ColorSpinorField &y) { auto a4 = cDotProductNormAB(x, y); - return make_double3(a4.x, a4.y, a4.w); + return {a4[0], a4[1], a4[3]}; } /** @@ -369,7 +369,7 @@ namespace quda { @param[in] w input vector @param[in] u input vector */ - double3 caxpbypzYmbwcDotProductUYNormY(const Complex &a, const ColorSpinorField &x, const Complex &b, + array caxpbypzYmbwcDotProductUYNormY(const complex_t &a, const ColorSpinorField &x, const complex_t &b, ColorSpinorField &y, ColorSpinorField &z, const ColorSpinorField &w, const ColorSpinorField &u); @@ -379,7 +379,7 @@ namespace quda { @param[in] x input vector @param[in,out] y update vector */ - double caxpyNorm(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y); + real_t caxpyNorm(const complex_t &a, const ColorSpinorField &x, ColorSpinorField &y); /**
@brief Compute z = a * b * x + y, x = a * x, and then ||x||^2 @@ -389,7 +389,7 @@ namespace quda { @param[in] y input vector @param[in,out] z update vector */ - double cabxpyzAxNorm(double a, const Complex &b, ColorSpinorField &x, const ColorSpinorField &y, + real_t cabxpyzAxNorm(real_t a, const complex_t &b, ColorSpinorField &x, const ColorSpinorField &y, ColorSpinorField &z); /** @@ -399,7 +399,7 @@ namespace quda { @param[in,out] y update vector @param[in] z input vector */ - Complex caxpyDotzy(const Complex &a, const ColorSpinorField &x, ColorSpinorField &y, + complex_t caxpyDotzy(const complex_t &a, const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &z); /** @@ -409,7 +409,7 @@ namespace quda { @param[in] x input vector @param[in,out] y update vector */ - double2 axpyCGNorm(double a, const ColorSpinorField &x, ColorSpinorField &y); + array axpyCGNorm(real_t a, const ColorSpinorField &x, ColorSpinorField &y); /** @brief Computes ||x||^2, ||r||^2 and the MILC/FNAL heavy quark @@ -417,7 +417,7 @@ namespace quda { @param[in] x input vector @param[in] r input vector (residual vector) */ - double3 HeavyQuarkResidualNorm(const ColorSpinorField &x, const ColorSpinorField &r); + array HeavyQuarkResidualNorm(const ColorSpinorField &x, const ColorSpinorField &r); /** @brief Computes y += x, ||y||^2, ||r||^2 and the MILC/FNAL heavy quark @@ -426,7 +426,7 @@ namespace quda { @param[in,out] y update vector @param[in] r input vector (residual vector) */ - double3 xpyHeavyQuarkResidualNorm(const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &r); + array xpyHeavyQuarkResidualNorm(const ColorSpinorField &x, ColorSpinorField &y, const ColorSpinorField &r); /** @brief Computes ||x||^2, ||y||^2, and real-valued inner product (y, z) @@ -434,7 +434,7 @@ namespace quda { @param[in] y input vector @param[in] z input vector */ - double3 tripleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z); + array tripleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z); /** @brief Computes ||x||^2, ||y||^2, the real-valued inner product (y, z), and ||z||^2 @@ -442,7 +442,7 @@ namespace quda { @param[in] y input vector @param[in] z input vector */ - double4 quadrupleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z); + array quadrupleCGReduction(const ColorSpinorField &x, const ColorSpinorField &y, const ColorSpinorField &z); /** @brief Computes z = x, w = y, x += a * y, y -= a * v and ||y||^2 @@ -453,7 +453,7 @@ namespace quda { @param[in,out] w update vector @param[in] v input vector */ - double quadrupleCG3InitNorm(double a, ColorSpinorField &x, ColorSpinorField &y, + real_t quadrupleCG3InitNorm(real_t a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, const ColorSpinorField &v); /** @@ -468,7 +468,7 @@ namespace quda { @param[in,out] w update vector @param[in] v input vector */ - double quadrupleCG3UpdateNorm(double a, double b, ColorSpinorField &x, ColorSpinorField &y, + real_t quadrupleCG3UpdateNorm(real_t a, real_t b, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z, ColorSpinorField &w, const ColorSpinorField &v); // multi-blas kernels - defined in multi_blas.cu @@ -477,7 +477,7 @@ namespace quda { @brief Compute the block "axpy" over the set of ColorSpinorFields. E.g., it computes y = x * a + y. The dimensions of a can be rectangular, e.g., the width of x and y need not be the same.
- @tparam T The type of a coefficients (double or Complex) + @tparam T The type of a coefficients (real_t or complex_t) @param a[in] Matrix of real coefficients @param x[in] vector of input ColorSpinorFields @param y[in,out] vector of input/output ColorSpinorFields @@ -493,7 +493,7 @@ namespace quda { Where 'a' must be a square, upper triangular matrix. - @tparam T The type of a coefficients (double or Complex) + @tparam T The type of a coefficients (real_t or complex_t) @param a[in] Matrix of coefficients @param x[in] vector of input ColorSpinorFields @param y[in,out] vector of input/output ColorSpinorFields @@ -509,7 +509,7 @@ namespace quda { Where 'a' must be a square, lower triangular matrix. - @tparam T The type of a coefficients (double or Complex) + @tparam T The type of a coefficients (real_t or complex_t) @param a[in] Matrix of coefficients @param x[in] vector of input ColorSpinorFields @param y[in,out] vector of input/output ColorSpinorFields @@ -530,7 +530,7 @@ namespace quda { @param x[in] vector of input ColorSpinorFields @param y[in,out] vector of input/output ColorSpinorFields */ - void caxpy(const std::vector &a, cvector_ref &x, cvector_ref &y); + void caxpy(const std::vector &a, cvector_ref &x, cvector_ref &y); /** @brief Compute the block "caxpy_U" over the set of @@ -544,7 +544,7 @@ namespace quda { @param x[in] vector of input ColorSpinorFields @param y[in,out] vector of input/output ColorSpinorFields */ - void caxpy_U(const std::vector &a, cvector_ref &x, cvector_ref &y); + void caxpy_U(const std::vector &a, cvector_ref &x, cvector_ref &y); /** @brief Compute the block "caxpy_L" over the set of @@ -558,7 +558,7 @@ namespace quda { @param x[in] vector of input ColorSpinorFields @param y[in,out] vector of input/output ColorSpinorFields */ - void caxpy_L(const std::vector &a, cvector_ref &x, cvector_ref &y); + void caxpy_L(const std::vector &a, cvector_ref &x, cvector_ref &y); /** @brief Compute the block "axpyz" over the set of @@ -575,7 +575,7 @@ namespace quda { @param y[in] vector of input ColorSpinorFields @param z[out] vector of output ColorSpinorFields */ - void axpyz(const std::vector &a, cvector_ref &x, + void axpyz(const std::vector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); /** @@ -591,7 +591,7 @@ namespace quda { @param y[in] vector of input ColorSpinorFields @param z[out] vector of output ColorSpinorFields */ - void axpyz_U(const std::vector &a, cvector_ref &x, + void axpyz_U(const std::vector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); /** @@ -607,7 +607,7 @@ namespace quda { @param y[in] vector of input ColorSpinorFields @param z[out] vector of output ColorSpinorFields */ - void axpyz_L(const std::vector &a, cvector_ref &x, + void axpyz_L(const std::vector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); /** @@ -625,7 +625,7 @@ namespace quda { @param y[in] vector of input ColorSpinorFields @param z[out] vector of output ColorSpinorFields */ - void caxpyz(const std::vector &a, cvector_ref &x, + void caxpyz(const std::vector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); /** @@ -641,7 +641,7 @@ namespace quda { @param y[in] vector of input ColorSpinorFields @param z[out] vector of output ColorSpinorFields */ - void caxpyz_U(const std::vector &a, cvector_ref &x, + void caxpyz_U(const std::vector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); /** @@ -657,7 +657,7 @@ namespace quda { @param y[in] vector of input ColorSpinorFields @param z[out] vector of output ColorSpinorFields */ - void caxpyz_L(const
std::vector &a, cvector_ref &x, + void caxpyz_L(const std::vector &a, cvector_ref &x, cvector_ref &y, cvector_ref &z); /** @@ -678,9 +678,9 @@ namespace quda { @param z[in] input ColorSpinorField @param c[in] Array of coefficients */ - void axpyBzpcx(const std::vector &a, cvector_ref &x, - cvector_ref &y, const std::vector &b, ColorSpinorField &z, - const std::vector &c); + void axpyBzpcx(const std::vector &a, cvector_ref &x, + cvector_ref &y, const std::vector &b, ColorSpinorField &z, + const std::vector &c); /** @brief Compute the vectorized "caxpyBxpz" over the set of @@ -699,8 +699,8 @@ namespace quda { @param b[in] Array of coefficients @param z[in,out] input ColorSpinorField */ - void caxpyBxpz(const std::vector &a, cvector_ref &x, ColorSpinorField &y, - const std::vector &b, ColorSpinorField &z); + void caxpyBxpz(const std::vector &a, cvector_ref &x, ColorSpinorField &y, + const std::vector &b, ColorSpinorField &z); // multi-reduce kernels - defined in multi_reduce.cu @@ -711,7 +711,7 @@ namespace quda { @param a[in] set of input ColorSpinorFields @param b[in] set of input ColorSpinorFields */ - void reDotProduct(std::vector &result, cvector_ref &a, + void reDotProduct(std::vector &result, cvector_ref &a, cvector_ref &b); /** @@ -721,7 +721,7 @@ namespace quda { @param a[in] set of input ColorSpinorFields @param b[in] set of input ColorSpinorFields */ - void cDotProduct(std::vector &result, cvector_ref &a, + void cDotProduct(std::vector &result, cvector_ref &a, cvector_ref &b); /** @@ -734,7 +734,7 @@ namespace quda { @param a[in] set of input ColorSpinorFields @param b[in] set of input ColorSpinorFields */ - void hDotProduct(std::vector &result, cvector_ref &a, + void hDotProduct(std::vector &result, cvector_ref &a, cvector_ref &b); /** @@ -749,47 +749,47 @@ namespace quda { @param a[in] set of input ColorSpinorFields @param b[in] set of input ColorSpinorFields */ - void hDotProduct_Anorm(std::vector &result, cvector_ref &a, + void hDotProduct_Anorm(std::vector &result, cvector_ref &a, cvector_ref &b); // compatibility wrappers until we switch to // std::vector and // std::vector> more broadly - void axpy(const double *a, std::vector &x, std::vector &y); - void axpy(const double *a, ColorSpinorField &x, ColorSpinorField &y); - void axpy_U(const double *a, std::vector &x, std::vector &y); - void axpy_U(const double *a, ColorSpinorField &x, ColorSpinorField &y); - void axpy_L(const double *a, std::vector &x, std::vector &y); - void axpy_L(const double *a, ColorSpinorField &x, ColorSpinorField &y); - void caxpy(const Complex *a, std::vector &x, std::vector &y); - void caxpy(const Complex *a, ColorSpinorField &x, ColorSpinorField &y); - void caxpy_U(const Complex *a, std::vector &x, std::vector &y); - void caxpy_U(const Complex *a, ColorSpinorField &x, ColorSpinorField &y); - void caxpy_L(const Complex *a, std::vector &x, std::vector &y); - void caxpy_L(const Complex *a, ColorSpinorField &x, ColorSpinorField &y); - void axpyz(const double *a, std::vector &x, std::vector &y, + void axpy(const real_t *a, std::vector &x, std::vector &y); + void axpy(const real_t *a, ColorSpinorField &x, ColorSpinorField &y); + void axpy_U(const real_t *a, std::vector &x, std::vector &y); + void axpy_U(const real_t *a, ColorSpinorField &x, ColorSpinorField &y); + void axpy_L(const real_t *a, std::vector &x, std::vector &y); + void axpy_L(const real_t *a, ColorSpinorField &x, ColorSpinorField &y); + void caxpy(const complex_t *a, std::vector &x, std::vector &y); + void caxpy(const complex_t *a, 
ColorSpinorField &x, ColorSpinorField &y); + void caxpy_U(const complex_t *a, std::vector &x, std::vector &y); + void caxpy_U(const complex_t *a, ColorSpinorField &x, ColorSpinorField &y); + void caxpy_L(const complex_t *a, std::vector &x, std::vector &y); + void caxpy_L(const complex_t *a, ColorSpinorField &x, ColorSpinorField &y); + void axpyz(const real_t *a, std::vector &x, std::vector &y, std::vector &z); - void axpyz(const double *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z); - void axpyz_U(const double *a, std::vector &x, std::vector &y, + void axpyz(const real_t *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z); + void axpyz_U(const real_t *a, std::vector &x, std::vector &y, std::vector &z); - void axpyz_L(const double *a, std::vector &x, std::vector &y, + void axpyz_L(const real_t *a, std::vector &x, std::vector &y, std::vector &z); - void caxpyz(const Complex *a, std::vector &x, std::vector &y, + void caxpyz(const complex_t *a, std::vector &x, std::vector &y, std::vector &z); - void caxpyz(const Complex *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z); - void caxpyz_U(const Complex *a, std::vector &x, std::vector &y, + void caxpyz(const complex_t *a, ColorSpinorField &x, ColorSpinorField &y, ColorSpinorField &z); + void caxpyz_U(const complex_t *a, std::vector &x, std::vector &y, std::vector &z); - void caxpyz_L(const Complex *a, std::vector &x, std::vector &y, + void caxpyz_L(const complex_t *a, std::vector &x, std::vector &y, std::vector &z); - void axpyBzpcx(const double *a, std::vector &x, std::vector &y, - const double *b, ColorSpinorField &z, const double *c); - void caxpyBxpz(const Complex *a_, std::vector &x_, ColorSpinorField &y_, const Complex *b_, + void axpyBzpcx(const real_t *a, std::vector &x, std::vector &y, + const real_t *b, ColorSpinorField &z, const real_t *c); + void caxpyBxpz(const complex_t *a_, std::vector &x_, ColorSpinorField &y_, const complex_t *b_, ColorSpinorField &z_); - void reDotProduct(double *result, std::vector &a, std::vector &b); - void cDotProduct(Complex *result, std::vector &a, std::vector &b); - void hDotProduct(Complex *result, std::vector &a, std::vector &b); + void reDotProduct(real_t *result, std::vector &a, std::vector &b); + void cDotProduct(complex_t *result, std::vector &a, std::vector &b); + void hDotProduct(complex_t *result, std::vector &a, std::vector &b); } // namespace blas diff --git a/include/clover_field.h b/include/clover_field.h index 380a399492..0e34f9b6e7 100644 --- a/include/clover_field.h +++ b/include/clover_field.h @@ -98,13 +98,13 @@ namespace quda { bool inverse = true; /** Whether to create the inverse clover field */ void *clover = nullptr; /** Pointer to the clover field */ void *cloverInv = nullptr; /** Pointer to the clover inverse field */ - double csw = 0.0; /** C_sw clover coefficient */ - double coeff = 0.0; /** Overall clover coefficient */ + real_t csw = 0.0; /** C_sw clover coefficient */ + real_t coeff = 0.0; /** Overall clover coefficient */ QudaTwistFlavorType twist_flavor = QUDA_TWIST_INVALID; /** Twisted-mass flavor type */ bool twisted = false; /** Whether to create twisted mass clover */ - double mu2 = 0.0; /** Chiral twisted mass term */ - double epsilon2 = 0.0; /** Flavor twisted mass term */ - double rho = 0.0; /** Hasenbusch rho term */ + real_t mu2 = 0.0; /** Chiral twisted mass term */ + real_t epsilon2 = 0.0; /** Flavor twisted mass term */ + real_t rho = 0.0; /** Hasenbusch rho term */ QudaCloverFieldOrder order = 
QUDA_INVALID_CLOVER_ORDER; /** Field order */ QudaFieldCreate create = QUDA_INVALID_FIELD_CREATE; @@ -182,20 +182,20 @@ namespace quda { quda_ptr cloverInv = {}; bool inverse = false; - double diagonal = 0.0; - array max = {}; + real_t diagonal = 0.0; + array max = {}; - double csw = 0.0; - double coeff = 0.0; + real_t csw = 0.0; + real_t coeff = 0.0; QudaTwistFlavorType twist_flavor = QUDA_TWIST_INVALID; - double mu2 = 0.0; // chiral twisted mass squared - double epsilon2 = 0.0; // flavour twisted mass squared - double rho = 0.0; + real_t mu2 = 0.0; // chiral twisted mass squared + real_t epsilon2 = 0.0; // flavour twisted mass squared + real_t rho = 0.0; QudaCloverFieldOrder order = QUDA_INVALID_CLOVER_ORDER; QudaFieldCreate create = QUDA_INVALID_FIELD_CREATE; - mutable array trlog = {}; + mutable array trlog = {}; /** @brief Set the vol_string and aux_string for use in tuning */ @@ -230,12 +230,12 @@ namespace quda { /** @return diagonal scaling factor applied to the identity */ - double Diagonal() const { return diagonal; } + real_t Diagonal() const { return diagonal; } /** @return set diagonal scaling factor applied to the identity */ - void Diagonal(double diagonal) { this->diagonal = diagonal; } + void Diagonal(real_t diagonal) { this->diagonal = diagonal; } /** @return max element in the clover field for fixed-point scaling */ @@ -296,12 +296,12 @@ namespace quda { /** @return Csw coefficient (does not include kappa) */ - double Csw() const { return csw; } + real_t Csw() const { return csw; } /** @return Clover coefficient (explicitly includes kappa) */ - double Coeff() const { return coeff; } + real_t Coeff() const { return coeff; } /** @return If the clover field is associated with twisted-clover fermions and which flavor type thereof */ @@ -311,24 +311,24 @@ namespace quda { /** @return mu^2 factor baked into inverse clover field (for twisted-clover inverse) */ - double Mu2() const { return mu2; } + real_t Mu2() const { return mu2; } /** @return epsilon^2 factor baked into inverse clover field (for non-deg twisted-clover inverse) */ - double Epsilon2() const { return epsilon2; } + real_t Epsilon2() const { return epsilon2; } /** @return rho factor baked into the clover field (for real diagonal additive Hasenbusch), e.g., A + rho */ - double Rho() const { return rho; } + real_t Rho() const { return rho; } /** @brief Bakes the rho factor into the clover field (for real diagonal additive Hasenbusch), e.g., A + rho */ - void setRho(double rho); + void setRho(real_t rho) { this->rho = rho; } /** @brief Copy into this CloverField from CloverField src @@ -348,25 +348,25 @@ namespace quda { @brief Compute the L1 norm of the field @return L1 norm */ - double norm1(bool inverse = false) const; + real_t norm1(bool inverse = false) const; /** @brief Compute the L2 norm squared of the field @return L2 norm squared */ - double norm2(bool inverse = false) const; + real_t norm2(bool inverse = false) const; /** @brief Compute the absolute maximum of the field (Linfinity norm) @return Absolute maximum value */ - double abs_max(bool inverse = false) const; + real_t abs_max(bool inverse = false) const; /** @brief Compute the absolute minimum of the field @return Absolute minimum value */ - double abs_min(bool inverse = false) const; + real_t abs_min(bool inverse = false) const; /** @brief Backs up the CloverField @@ -421,7 +421,7 @@ namespace quda { @param u The clover field that we want the norm of @return The L1 norm of the clover field */ - double norm1(const CloverField &u,
bool inverse=false); + real_t norm1(const CloverField &u, bool inverse=false); /** This is a debugging function, where we cast a clover field into a @@ -429,7 +429,7 @@ namespace quda { @param a The clover field that we want the norm of @return The L2 norm squared of the clover field */ - double norm2(const CloverField &a, bool inverse=false); + real_t norm2(const CloverField &a, bool inverse=false); /** @brief Driver for computing the clover field from the field @@ -438,7 +438,7 @@ namespace quda { @param[in] fmunu Field strength tensor @param[in] coeff Clover coefficient */ - void computeClover(CloverField &clover, const GaugeField &fmunu, double coeff); + void computeClover(CloverField &clover, const GaugeField &fmunu, real_t coeff); /** @brief This generic function is used for copying the clover field where @@ -482,7 +482,7 @@ namespace quda { */ void computeCloverForce(GaugeField& force, const GaugeField& U, std::vector &x, std::vector &p, - std::vector &coeff); + std::vector &coeff); /** @brief Compute the outer product from the solver solution fields arising from the diagonal term of the fermion bilinear in @@ -496,7 +496,7 @@ namespace quda { void computeCloverSigmaOprod(GaugeField& oprod, std::vector &x, std::vector &p, - std::vector< std::vector > &coeff); + std::vector< std::vector > &coeff); /** @brief Compute the matrix tensor field necessary for the force calculation from the clover trace action. This computes a tensor field [mu,nu]. @@ -505,7 +505,7 @@ namespace quda { @param clover The input clover field @param coeff Scalar coefficient multiplying the result (e.g., stepsize) */ - void computeCloverSigmaTrace(GaugeField &output, const CloverField &clover, double coeff); + void computeCloverSigmaTrace(GaugeField &output, const CloverField &clover, real_t coeff); /** @brief Compute the derivative of the clover matrix in the direction @@ -518,7 +518,7 @@ namespace quda { @param coeff Multiplicative coefficient (e.g., clover coefficient) @param parity The field parity we are working on */ - void cloverDerivative(GaugeField &force, GaugeField &gauge, GaugeField &oprod, double coeff, QudaParity parity); + void cloverDerivative(GaugeField &force, GaugeField &gauge, GaugeField &oprod, real_t coeff, QudaParity parity); /** @brief This function is used for copying from a source clover field to a destination clover field diff --git a/include/clover_field_order.h b/include/clover_field_order.h index 65d5ef6cff..8ecac235fe 100644 --- a/include/clover_field_order.h +++ b/include/clover_field_order.h @@ -291,9 +291,9 @@ namespace quda { @tparam helper The helper functor which acts as the transformer in transform_reduce */ - template constexpr double transform_reduce(QudaFieldLocation, helper) const + template constexpr auto transform_reduce(QudaFieldLocation, helper) const { - return 0.0; + return real_t(0.0); } }; @@ -316,7 +316,7 @@ namespace quda { stride(A.VolumeCB()), offset_cb(A.Bytes() / (2 * sizeof(Float))), compressed_block_size(A.compressed_block_size()), - recon(A.Diagonal()) + recon(double(A.Diagonal())) { } @@ -361,7 +361,7 @@ namespace quda { in transform_reduce */ template - __host__ double
transform_reduce(QudaFieldLocation location, helper h) const + auto transform_reduce(QudaFieldLocation location, helper h) const { // just use offset_cb, since factor of two from parity is equivalent to complexity return ::quda::transform_reduce(location, reinterpret_cast *>(a), offset_cb, h); } @@ -440,7 +440,7 @@ namespace quda { in transform_reduce */ template - __host__ double transform_reduce(QudaFieldLocation location, helper h) const + auto transform_reduce(QudaFieldLocation location, helper h) const { return ::quda::transform_reduce(location, reinterpret_cast *>(a), offset_cb, h); } @@ -518,11 +518,10 @@ namespace quda { * @param[in] dim Which dimension we are taking the norm of (dummy for clover) * @return L1 norm */ - __host__ double norm1(int = -1, bool global = true) const + auto norm1(int = -1, bool global = true) const { commGlobalReductionPush(global); - double nrm1 - = accessor.scale() * accessor.template transform_reduce>(location, abs_()); + real_t nrm1 = real_t(accessor.scale() * accessor.template transform_reduce>(location, abs_())); commGlobalReductionPop(); return nrm1; } @@ -532,11 +531,11 @@ namespace quda { * @param[in] dim Which dimension we are taking the norm of (dummy for clover) * @return L2 norm squared */ - __host__ double norm2(int = -1, bool global = true) const + auto norm2(int = -1, bool global = true) const { commGlobalReductionPush(global); - double nrm2 = accessor.scale() * accessor.scale() - * accessor.template transform_reduce>(location, square_()); + real_t nrm2 = real_t(accessor.scale() * accessor.scale() + * accessor.template transform_reduce>(location, square_())); commGlobalReductionPop(); return nrm2; } @@ -546,11 +545,10 @@ namespace quda { * @param[in] dim Which dimension we are taking the Linfinity norm of (dummy for clover) * @return Linfinity norm */ - __host__ double abs_max(int = -1, bool global = true) const + auto abs_max(int = -1, bool global = true) const { commGlobalReductionPush(global); - double absmax - = accessor.scale() * accessor.template transform_reduce>(location, abs_max_()); + real_t absmax = real_t(accessor.scale() * accessor.template transform_reduce>(location, abs_max_())); commGlobalReductionPop(); return absmax; } @@ -560,11 +558,10 @@ namespace quda { * @param[in] dim Which dimension we are taking the minimum abs of (dummy for clover) * @return Minimum norm */ - __host__ double abs_min(int = -1, bool global = true) const + auto abs_min(int = -1, bool global = true) const { commGlobalReductionPush(global); - double absmin - = accessor.scale() * accessor.template transform_reduce>(location, abs_min_()); + real_t absmin = real_t(accessor.scale() * accessor.template transform_reduce>(location, abs_min_())); commGlobalReductionPop(); return absmin; } @@ -618,7 +615,7 @@ namespace quda { void *backup_h; //! host memory for backing up the field when tuning FloatNOrder(const CloverField &clover, bool is_inverse, Float *clover_ = nullptr) : - recon(clover.Diagonal()), + recon(double(clover.Diagonal())), nrm(clover.max_element(is_inverse) / (2 * (isFixed::value ? fixedMaxValue::value : 1))), // factor of two in normalization nrm_inv(1.0 / nrm), diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 87f10bce60..27dc9ba241 100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -703,7 +703,7 @@ namespace quda nParity(field.SiteSubset()), ghostAccessor(field, nFace) { resetGhost(ghost_ ?
ghost_ : field.Ghost()); - resetScale(field.Scale()); + resetScale(double(field.Scale())); } GhostOrder &operator=(const GhostOrder &) = default; @@ -803,8 +803,8 @@ namespace quda commGlobalReductionPush(global); Float scale_inv = 1.0; if constexpr (fixed && !block_float_ghost) scale_inv = ghost.scale_inv; - auto nrm2 = transform_reduce>(dim, field.Location(), field.SiteSubset(), - square_(scale_inv)); + real_t nrm2 = real_t(transform_reduce>(dim, field.Location(), field.SiteSubset(), + square_(scale_inv))); commGlobalReductionPop(); return nrm2; } @@ -821,8 +821,8 @@ namespace quda commGlobalReductionPush(global); Float scale_inv = 1.0; if constexpr (fixed && !block_float_ghost) scale_inv = ghost.scale_inv; - auto absmax = transform_reduce>(field.Location(), field.SiteSubset(), - abs_max_(scale_inv)); + real_t absmax = real_t(transform_reduce>(field.Location(), field.SiteSubset(), + abs_max_(scale_inv))); commGlobalReductionPop(); return absmax; } @@ -878,7 +878,7 @@ namespace quda GhostOrder(field, nFace, ghost_), volumeCB(field.VolumeCB()), accessor(field) { v.v = v_ ? static_cast *>(const_cast(v_)) : field.data *>(); - resetScale(field.Scale()); + resetScale(double(field.Scale())); if constexpr (fixed && block_float) { if constexpr (nColor == 3 && nSpin == 1 && nVec == 1 && order == 2) @@ -1022,8 +1022,8 @@ namespace quda commGlobalReductionPush(global); Float scale_inv = 1.0; if constexpr (fixed && !block_float) scale_inv = v.scale_inv; - auto nrm2 - = transform_reduce>(field.Location(), field.SiteSubset(), square_(scale_inv)); + real_t nrm2 = real_t(transform_reduce>(field.Location(), field.SiteSubset(), + square_(scale_inv))); commGlobalReductionPop(); return nrm2; } @@ -1039,8 +1039,8 @@ namespace quda commGlobalReductionPush(global); Float scale_inv = 1.0; if constexpr (fixed && !block_float) scale_inv = v.scale_inv; - auto absmax = transform_reduce>(field.Location(), field.SiteSubset(), - abs_max_(scale_inv)); + auto absmax = real_t(transform_reduce>(field.Location(), field.SiteSubset(), + abs_max_(scale_inv))); commGlobalReductionPop(); return absmax; } diff --git a/include/communicator_quda.h b/include/communicator_quda.h index 0dc22fdfab..c9c59c3dbd 100644 --- a/include/communicator_quda.h +++ b/include/communicator_quda.h @@ -14,6 +14,7 @@ #include #include #include +#include "reducer.h" #if defined(MPI_COMMS) || defined(QMP_COMMS) #include @@ -733,21 +734,19 @@ namespace quda int comm_query(MsgHandle *mh); - template T deterministic_reduce(T *array, int n) + template T deterministic_sum_reduce(T *array, int n) { std::sort(array, array + n); // sort reduction into ascending order for deterministic reduction - return std::accumulate(array, array + n, 0.0); + return std::accumulate(array, array + n, T(0.0)); } - void comm_allreduce_sum_array(double *data, size_t size); + template void comm_allreduce_sum_array(T *data, size_t size); - void comm_allreduce_sum(size_t &a); - - void comm_allreduce_max_array(double *data, size_t size); + template void comm_allreduce_max_array(T *data, size_t size); - void comm_allreduce_max_array(deviation_t *data, size_t size); + void comm_allreduce_sum(size_t &a); - void comm_allreduce_min_array(double *data, size_t size); + template void comm_allreduce_min_array(T *data, size_t size); void comm_allreduce_int(int &data); diff --git a/include/complex_quda.h b/include/complex_quda.h index 66b5eee609..28bf74be05 100644 --- a/include/complex_quda.h +++ b/include/complex_quda.h @@ -26,6 +26,7 @@ #include #include #include // for double2 / float2 
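A quick aside on the deterministic_sum_reduce helper above: floating-point addition is not associative, so the same partial results summed in a different order can give a different answer; sorting first fixes the summation order and hence the result. A minimal host-side sketch (the helper is re-created here so the snippet stands alone; the values are illustrative only):

```cpp
#include <algorithm>
#include <numeric>
#include <cstdio>

// Mirrors the deterministic_sum_reduce helper above: sorting fixes the
// order of the additions, so the result no longer depends on arrival order.
template <typename T> T deterministic_sum_reduce(T *array, int n)
{
  std::sort(array, array + n);
  return std::accumulate(array, array + n, T(0.0));
}

int main()
{
  // The same four partials in two arrival orders...
  double a[] = {1e16, 1.0, -1e16, 1.0};
  double b[] = {1.0, 1e16, 1.0, -1e16};
  // ...disagree when summed as-is (1 vs 0 here)...
  printf("unsorted: %g vs %g\n", std::accumulate(a, a + 4, 0.0), std::accumulate(b, b + 4, 0.0));
  // ...but agree bit-for-bit once both are sorted into the same order.
  printf("sorted:   %g vs %g\n", deterministic_sum_reduce(a, 4), deterministic_sum_reduce(b, 4));
}
```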
+#include "dbldbl.h" namespace quda { namespace gauge { @@ -124,6 +125,7 @@ namespace quda __host__ __device__ inline float conj(float x) { return x; } __host__ __device__ inline double conj(double x) { return x; } + __host__ __device__ inline doubledouble conj(doubledouble x) { return x; } template struct complex; //template <> struct complex; @@ -427,10 +429,10 @@ struct complex return *this; } - __host__ __device__ inline ValueType real() const; - __host__ __device__ inline ValueType imag() const; - __host__ __device__ inline void real(ValueType); - __host__ __device__ inline void imag(ValueType); + constexpr ValueType real() const; + constexpr ValueType imag() const; + __host__ __device__ inline void real(ValueType); + __host__ __device__ inline void imag(ValueType); }; template <> struct complex : public float2 { @@ -617,25 +619,25 @@ struct complex : public short2 real(real() + z.real()); imag(imag() + z.imag()); return *this; - } - - __host__ __device__ inline complex &operator-=(const complex &z) - { - real(real()-z.real()); - imag(imag()-z.imag()); - return *this; - } + } - constexpr short real() const { return x; } - constexpr short imag() const { return y; } - __host__ __device__ inline void real(short re) { x = re; } - __host__ __device__ inline void imag(short im) { y = im; } + __host__ __device__ inline complex &operator-=(const complex &z) + { + real(real() - z.real()); + imag(imag() - z.imag()); + return *this; + } - // cast operators - template inline __host__ __device__ operator complex() const - { - return complex(static_cast(real()), static_cast(imag())); } + constexpr short real() const { return x; } + constexpr short imag() const { return y; } + __host__ __device__ inline void real(short re) { x = re; } + __host__ __device__ inline void imag(short im) { y = im; } + // cast operators + template inline __host__ __device__ operator complex() const + { + return complex(static_cast(real()), static_cast(imag())); + } }; template<> @@ -653,25 +655,62 @@ struct complex : public int2 real(real() + z.real()); imag(imag() + z.imag()); return *this; - } + } - __host__ __device__ inline complex &operator-=(const complex &z) - { - real(real()-z.real()); - imag(imag()-z.imag()); - return *this; - } + __host__ __device__ inline complex &operator-=(const complex &z) + { + real(real() - z.real()); + imag(imag() - z.imag()); + return *this; + } - constexpr int real() const { return x; } - constexpr int imag() const { return y; } - __host__ __device__ inline void real(int re) { x = re; } - __host__ __device__ inline void imag(int im) { y = im; } + constexpr int real() const { return x; } + constexpr int imag() const { return y; } + __host__ __device__ inline void real(int re) { x = re; } + __host__ __device__ inline void imag(int im) { y = im; } - // cast operators - template inline __host__ __device__ operator complex() const - { - return complex(static_cast(real()), static_cast(imag())); } + // cast operators + template inline __host__ __device__ operator complex() const + { + return complex(static_cast(real()), static_cast(imag())); + } +}; + +template <> struct complex : public doubledouble2 { +public: + typedef doubledouble value_type; + + complex() = default; + + constexpr complex(const doubledouble &re, const doubledouble &im = doubledouble()) : + doubledouble2 {re, im} + { + } + + __host__ __device__ inline complex &operator+=(const complex &z) + { + real(real() + z.real()); + imag(imag() + z.imag()); + return *this; + } + __host__ __device__ inline complex &operator-=(const 
complex &z) + { + real(real() - z.real()); + imag(imag() - z.imag()); + return *this; + } + + constexpr doubledouble real() const { return x; } + constexpr doubledouble imag() const { return y; } + __host__ __device__ inline void real(doubledouble re) { x = re; } + __host__ __device__ inline void imag(doubledouble im) { y = im; } + + // cast operators + template inline __host__ __device__ operator complex() const + { + return complex(static_cast(real()), static_cast(imag())); + } }; // Binary arithmetic operations @@ -1169,9 +1208,9 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); { complex w; w.x = x.real() * y.real(); - w.x -= x.imag() * y.imag(); + w.x = fma(-x.imag(), y.imag(), w.x); w.y = x.imag() * y.real(); - w.y += x.real() * y.imag(); + w.y = fma(x.real(), y.imag(), w.y); return w; } @@ -1179,10 +1218,10 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); __host__ __device__ inline complex cmac(const complex &x, const complex &y, const complex &z) { complex w = z; - w.x += x.real() * y.real(); - w.x -= x.imag() * y.imag(); - w.y += x.imag() * y.real(); - w.y += x.real() * y.imag(); + w.x = fma( x.real(), y.real(), w.x); + w.x = fma(-x.imag(), y.imag(), w.x); + w.y = fma( x.imag(), y.real(), w.y); + w.y = fma( x.real(), y.imag(), w.y); return w; } @@ -1197,10 +1236,10 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real()); complex X = x; complex Y = y; complex Z = z; - Z.real(Z.real() + X.real() * Y.real()); - Z.real(Z.real() - X.imag() * Y.imag()); - Z.imag(Z.imag() + X.imag() * Y.real()); - Z.imag(Z.imag() + X.real() * Y.imag()); + Z.real(fma( X.real(), Y.real(), Z.real())); + Z.real(fma(-X.imag(), Y.imag(), Z.real())); + Z.imag(fma( X.imag(), Y.real(), Z.imag())); + Z.imag(fma( X.real(), Y.imag(), Z.imag())); return Z; } diff --git a/include/dbldbl.h b/include/dbldbl.h index f4b198e373..084ba41fde 100644 --- a/include/dbldbl.h +++ b/include/dbldbl.h @@ -66,7 +66,7 @@ typedef double2 dbldbl; requirement. To create a double-double from two arbitrary double-precision numbers, use add_double_to_dbldbl(). */ -__device__ __forceinline__ dbldbl make_dbldbl (double head, double tail) +__device__ __host__ __forceinline__ dbldbl make_dbldbl (double head, double tail) { dbldbl z; z.x = tail; @@ -75,42 +75,42 @@ __device__ __forceinline__ dbldbl make_dbldbl (double head, double tail) } /* Return the head of a double-double number */ -__device__ __forceinline__ double get_dbldbl_head (dbldbl a) +__device__ __host__ __forceinline__ double get_dbldbl_head (dbldbl a) { return a.y; } /* Return the tail of a double-double number */ -__device__ __forceinline__ double get_dbldbl_tail (dbldbl a) +__device__ __host__ __forceinline__ double get_dbldbl_tail (dbldbl a) { return a.x; } /* Compute error-free sum of two unordered doubles. See Knuth, TAOCP vol. 2 */ -__device__ __forceinline__ dbldbl add_double_to_dbldbl (double a, double b) +__device__ __host__ __forceinline__ dbldbl add_double_to_dbldbl (double a, double b) { double t1, t2; dbldbl z; - z.y = __dadd_rn (a, b); - t1 = __dadd_rn (z.y, -a); - t2 = __dadd_rn (z.y, -t1); - t1 = __dadd_rn (b, -t1); - t2 = __dadd_rn (a, -t2); - z.x = __dadd_rn (t1, t2); + z.y = quda::dadd_rn (a, b); + t1 = quda::dadd_rn (z.y, -a); + t2 = quda::dadd_rn (z.y, -t1); + t1 = quda::dadd_rn (b, -t1); + t2 = quda::dadd_rn (a, -t2); + z.x = quda::dadd_rn (t1, t2); return z; } /* Compute error-free product of two doubles. 
Take full advantage of FMA */ -__device__ __forceinline__ dbldbl mul_double_to_dbldbl (double a, double b) +__device__ __host__ __forceinline__ dbldbl mul_double_to_dbldbl (double a, double b) { dbldbl z; - z.y = __dmul_rn (a, b); - z.x = __fma_rn (a, b, -z.y); + z.y = quda::dmul_rn (a, b); + z.x = quda::fma_rn (a, b, -z.y); return z; } /* Negate a double-double number, by separately negating head and tail */ -__device__ __forceinline__ dbldbl neg_dbldbl (dbldbl a) +__device__ __host__ __forceinline__ dbldbl neg_dbldbl (dbldbl a) { dbldbl z; z.y = -a.y; @@ -125,22 +125,22 @@ __device__ __forceinline__ dbldbl neg_dbldbl (dbldbl a) Floating-Point Numbers for GPU Computation. Retrieved on 7/12/2011 from http://andrewthall.org/papers/df64_qf128.pdf. */ -__device__ __forceinline__ dbldbl add_dbldbl (dbldbl a, dbldbl b) +__device__ __host__ __forceinline__ dbldbl add_dbldbl (dbldbl a, dbldbl b) { dbldbl z; double t1, t2, t3, t4, t5, e; - t1 = __dadd_rn (a.y, b.y); - t2 = __dadd_rn (t1, -a.y); - t3 = __dadd_rn (__dadd_rn (a.y, t2 - t1), __dadd_rn (b.y, -t2)); - t4 = __dadd_rn (a.x, b.x); - t2 = __dadd_rn (t4, -a.x); - t5 = __dadd_rn (__dadd_rn (a.x, t2 - t4), __dadd_rn (b.x, -t2)); - t3 = __dadd_rn (t3, t4); - t4 = __dadd_rn (t1, t3); - t3 = __dadd_rn (t1 - t4, t3); - t3 = __dadd_rn (t3, t5); - z.y = e = __dadd_rn (t4, t3); - z.x = __dadd_rn (t4 - e, t3); + t1 = quda::dadd_rn (a.y, b.y); + t2 = quda::dadd_rn (t1, -a.y); + t3 = quda::dadd_rn (quda::dadd_rn (a.y, t2 - t1), quda::dadd_rn (b.y, -t2)); + t4 = quda::dadd_rn (a.x, b.x); + t2 = quda::dadd_rn (t4, -a.x); + t5 = quda::dadd_rn (quda::dadd_rn (a.x, t2 - t4), quda::dadd_rn (b.x, -t2)); + t3 = quda::dadd_rn (t3, t4); + t4 = quda::dadd_rn (t1, t3); + t3 = quda::dadd_rn (t1 - t4, t3); + t3 = quda::dadd_rn (t3, t5); + z.y = e = quda::dadd_rn (t4, t3); + z.x = quda::dadd_rn (t4 - e, t3); return z; } @@ -151,22 +151,22 @@ __device__ __forceinline__ dbldbl add_dbldbl (dbldbl a, dbldbl b) Floating-Point Numbers for GPU Computation. Retrieved on 7/12/2011 from http://andrewthall.org/papers/df64_qf128.pdf. */ -__device__ __forceinline__ dbldbl sub_dbldbl (dbldbl a, dbldbl b) +__device__ __host__ __forceinline__ dbldbl sub_dbldbl (dbldbl a, dbldbl b) { dbldbl z; double t1, t2, t3, t4, t5, e; - t1 = __dadd_rn (a.y, -b.y); - t2 = __dadd_rn (t1, -a.y); - t3 = __dadd_rn (__dadd_rn (a.y, t2 - t1), - __dadd_rn (b.y, t2)); - t4 = __dadd_rn (a.x, -b.x); - t2 = __dadd_rn (t4, -a.x); - t5 = __dadd_rn (__dadd_rn (a.x, t2 - t4), - __dadd_rn (b.x, t2)); - t3 = __dadd_rn (t3, t4); - t4 = __dadd_rn (t1, t3); - t3 = __dadd_rn (t1 - t4, t3); - t3 = __dadd_rn (t3, t5); - z.y = e = __dadd_rn (t4, t3); - z.x = __dadd_rn (t4 - e, t3); + t1 = quda::dadd_rn (a.y, -b.y); + t2 = quda::dadd_rn (t1, -a.y); + t3 = quda::dadd_rn (quda::dadd_rn (a.y, t2 - t1), - quda::dadd_rn (b.y, t2)); + t4 = quda::dadd_rn (a.x, -b.x); + t2 = quda::dadd_rn (t4, -a.x); + t5 = quda::dadd_rn (quda::dadd_rn (a.x, t2 - t4), - quda::dadd_rn (b.x, t2)); + t3 = quda::dadd_rn (t3, t4); + t4 = quda::dadd_rn (t1, t3); + t3 = quda::dadd_rn (t1 - t4, t3); + t3 = quda::dadd_rn (t3, t5); + z.y = e = quda::dadd_rn (t4, t3); + z.x = quda::dadd_rn (t4 - e, t3); return z; } @@ -175,17 +175,17 @@ __device__ __forceinline__ dbldbl sub_dbldbl (dbldbl a, dbldbl b) relative error observed with 10 billion test cases was 5.238480533564479e-32 (~= 2**-103.9125). 
*/ -__device__ __forceinline__ dbldbl mul_dbldbl (dbldbl a, dbldbl b) +__device__ __host__ __forceinline__ dbldbl mul_dbldbl (dbldbl a, dbldbl b) { dbldbl t, z; double e; - t.y = __dmul_rn (a.y, b.y); - t.x = __fma_rn (a.y, b.y, -t.y); - t.x = __fma_rn (a.x, b.x, t.x); - t.x = __fma_rn (a.y, b.x, t.x); - t.x = __fma_rn (a.x, b.y, t.x); - z.y = e = __dadd_rn (t.y, t.x); - z.x = __dadd_rn (t.y - e, t.x); + t.y = quda::dmul_rn (a.y, b.y); + t.x = quda::fma_rn (a.y, b.y, -t.y); + t.x = quda::fma_rn (a.x, b.x, t.x); + t.x = quda::fma_rn (a.y, b.x, t.x); + t.x = quda::fma_rn (a.x, b.y, t.x); + z.y = e = quda::dadd_rn (t.y, t.x); + z.x = quda::dadd_rn (t.y - e, t.x); return z; } @@ -197,22 +197,22 @@ __device__ __forceinline__ dbldbl mul_dbldbl (dbldbl a, dbldbl b) maximum relative error observed with 10 billion test cases was 1.0161322480099059e-31 (~= 2**-102.9566). */ -__device__ __forceinline__ dbldbl div_dbldbl (dbldbl a, dbldbl b) +__device__ __host__ __forceinline__ dbldbl div_dbldbl (dbldbl a, dbldbl b) { dbldbl t, z; double e, r; r = 1.0 / b.y; - t.y = __dmul_rn (a.y, r); - e = __fma_rn (b.y, -t.y, a.y); - t.y = __fma_rn (r, e, t.y); - t.x = __fma_rn (b.y, -t.y, a.y); - t.x = __dadd_rn (a.x, t.x); - t.x = __fma_rn (b.x, -t.y, t.x); - e = __dmul_rn (r, t.x); - t.x = __fma_rn (b.y, -e, t.x); - t.x = __fma_rn (r, t.x, e); - z.y = e = __dadd_rn (t.y, t.x); - z.x = __dadd_rn (t.y - e, t.x); + t.y = quda::dmul_rn (a.y, r); + e = quda::fma_rn (b.y, -t.y, a.y); + t.y = quda::fma_rn (r, e, t.y); + t.x = quda::fma_rn (b.y, -t.y, a.y); + t.x = quda::dadd_rn (a.x, t.x); + t.x = quda::fma_rn (b.x, -t.y, t.x); + e = quda::dmul_rn (r, t.x); + t.x = quda::fma_rn (b.y, -e, t.x); + t.x = quda::fma_rn (r, t.x, e); + z.y = e = quda::dadd_rn (t.y, t.x); + z.x = quda::dadd_rn (t.y - e, t.x); return z; } @@ -223,25 +223,25 @@ __device__ __forceinline__ dbldbl div_dbldbl (dbldbl a, dbldbl b) relative error observed with 10 billion test cases was 3.7564109505601846e-32 (~= 2**-104.3923). 
*/ -__device__ __forceinline__ dbldbl sqrt_dbldbl (dbldbl a) +__device__ __host__ __forceinline__ dbldbl sqrt_dbldbl (dbldbl a) { dbldbl t, z; double e, y, s, r; r = quda::rsqrt(a.y); if (a.y == 0.0) r = 0.0; - y = __dmul_rn (a.y, r); - s = __fma_rn (y, -y, a.y); - r = __dmul_rn (0.5, r); - z.y = e = __dadd_rn (s, a.x); - z.x = __dadd_rn (s - e, a.x); - t.y = __dmul_rn (r, z.y); - t.x = __fma_rn (r, z.y, -t.y); - t.x = __fma_rn (r, z.x, t.x); - r = __dadd_rn (y, t.y); - s = __dadd_rn (y - r, t.y); - s = __dadd_rn (s, t.x); - z.y = e = __dadd_rn (r, s); - z.x = __dadd_rn (r - e, s); + y = quda::dmul_rn (a.y, r); + s = quda::fma_rn (y, -y, a.y); + r = quda::dmul_rn (0.5, r); + z.y = e = quda::dadd_rn (s, a.x); + z.x = quda::dadd_rn (s - e, a.x); + t.y = quda::dmul_rn (r, z.y); + t.x = quda::fma_rn (r, z.y, -t.y); + t.x = quda::fma_rn (r, z.x, t.x); + r = quda::dadd_rn (y, t.y); + s = quda::dadd_rn (y - r, t.y); + s = quda::dadd_rn (s, t.x); + z.y = e = quda::dadd_rn (r, s); + z.x = quda::dadd_rn (r - e, s); return z; } @@ -250,26 +250,26 @@ __device__ __forceinline__ dbldbl sqrt_dbldbl (dbldbl a) the maximum relative error observed with 10 billion test cases was 6.4937771666026349e-32 (~= 2**-103.6026) */ -__device__ __forceinline__ dbldbl rsqrt_dbldbl (dbldbl a) +__device__ __host__ __forceinline__ dbldbl rsqrt_dbldbl (dbldbl a) { dbldbl z; double r, s, e; r = quda::rsqrt(a.y); - e = __dmul_rn (a.y, r); - s = __fma_rn (e, -r, 1.0); - e = __fma_rn (a.y, r, -e); - s = __fma_rn (e, -r, s); - e = __dmul_rn (a.x, r); - s = __fma_rn (e, -r, s); + e = quda::dmul_rn (a.y, r); + s = quda::fma_rn (e, -r, 1.0); + e = quda::fma_rn (a.y, r, -e); + s = quda::fma_rn (e, -r, s); + e = quda::dmul_rn (a.x, r); + s = quda::fma_rn (e, -r, s); e = 0.5 * r; - z.y = __dmul_rn (e, s); - z.x = __fma_rn (e, s, -z.y); - s = __dadd_rn (r, z.y); - r = __dadd_rn (r, -s); - r = __dadd_rn (r, z.y); - r = __dadd_rn (r, z.x); - z.y = e = __dadd_rn (s, r); - z.x = __dadd_rn (s - e, r); + z.y = quda::dmul_rn (e, s); + z.x = quda::fma_rn (e, s, -z.y); + s = quda::dadd_rn (r, z.y); + r = quda::dadd_rn (r, -s); + r = quda::dadd_rn (r, z.y); + r = quda::dadd_rn (r, z.x); + z.y = e = quda::dadd_rn (s, r); + z.x = quda::dadd_rn (s - e, r); return z; } @@ -285,97 +285,162 @@ struct doubledouble { dbldbl a; - __device__ __host__ doubledouble() { a.x = 0.0; a.y = 0.0; } - __device__ __host__ doubledouble(const doubledouble &a) : a(a.a) { } - __device__ __host__ doubledouble(const dbldbl &a) : a(a) { } - __device__ __host__ doubledouble(const double &head, const double &tail) { a.y = head; a.x = tail; } - __device__ __host__ doubledouble(const double &head) { a.y = head; a.x = 0.0; } + doubledouble() = default; + constexpr doubledouble(const doubledouble &a) = default; + constexpr doubledouble(const dbldbl &a) : a(a) { } + constexpr doubledouble(const double &head, const double &tail) : a{tail, head} { } + constexpr doubledouble(const double &head) : a{0.0, head} { } __device__ __host__ doubledouble& operator=(const double &head) { this->a.y = head; this->a.x = 0.0; + return *this; } - __device__ doubledouble& operator+=(const doubledouble &a) { + __device__ __host__ doubledouble& operator+=(const doubledouble &a) { this->a = add_dbldbl(this->a, a.a); return *this; } - __device__ __host__ double head() const { return a.y; } - __device__ __host__ double tail() const { return a.x; } + __device__ __host__ doubledouble& operator-=(const doubledouble &a) { + this->a = sub_dbldbl(this->a, a.a); + return *this; + } + + __device__ __host__ 
void operator*=(const doubledouble &a) + { + this->a = mul_dbldbl(this->a, a.a); + } + + __device__ __host__ void operator/=(const doubledouble &a) + { + this->a = div_dbldbl(this->a, a.a); + } + + constexpr double head() const { return a.y; } + constexpr double tail() const { return a.x; } __device__ __host__ void print() const { printf("scalar: %16.14e + %16.14e\n", head(), tail()); } + explicit constexpr operator double() const { return head(); } + explicit constexpr operator float() const { return static_cast<float>(head()); } + explicit constexpr operator int() const { return static_cast<int>(head()); } }; -__device__ inline bool operator>(const doubledouble &a, const double &b) { - return a.head() > b; +__device__ __host__ inline doubledouble sqrt(const doubledouble &a) { return doubledouble(sqrt_dbldbl(a.a)); } + +__device__ __host__ inline doubledouble operator-(const doubledouble &a) { return doubledouble(neg_dbldbl(a.a)); } + +__device__ __host__ inline doubledouble abs(const doubledouble &a) { return (a.head() < 0 ? -a : a); } + +__device__ __host__ inline bool isinf(const doubledouble &a) { return isinf(a.head()); } + +__device__ __host__ inline bool isnan(const doubledouble &a) { return isnan(a.head()); } + +__device__ __host__ inline bool isfinite(const doubledouble &a) { return isfinite(a.head()); } + +__device__ __host__ inline bool operator>(const doubledouble &a, const doubledouble &b) +{ + if (a.head() > b.head()) { + return true; + } else if (a.head() == b.head() && a.tail() > b.tail()) { + return true; + } else { + return false; + } } -__device__ inline doubledouble operator+(const doubledouble &a, const doubledouble &b) { - return doubledouble(add_dbldbl(a.a,b.a)); +__device__ __host__ inline bool operator>=(const doubledouble &a, const doubledouble &b) +{ + if (a.head() > b.head()) { + return true; + } else if (a.head() == b.head() && a.tail() >= b.tail()) { + return true; + } else { + return false; + } } -__device__ inline doubledouble operator-(const doubledouble &a, const doubledouble &b) { - return doubledouble(sub_dbldbl(a.a,b.a)); +__device__ __host__ inline bool operator<(const doubledouble &a, const doubledouble &b) +{ + if (a.head() < b.head()) { + return true; + } else if (a.head() == b.head() && a.tail() < b.tail()) { + return true; + } else { + return false; + } } -__device__ inline doubledouble operator*(const doubledouble &a, const doubledouble &b) { - return doubledouble(mul_dbldbl(a.a,b.a)); +__device__ __host__ inline bool operator<=(const doubledouble &a, const doubledouble &b) +{ + if (a.head() < b.head()) { + return true; + } else if (a.head() == b.head() && a.tail() <= b.tail()) { + return true; + } else { + return false; + } } -__device__ inline doubledouble operator/(const doubledouble &a, const doubledouble &b) { - return doubledouble(div_dbldbl(a.a,b.a)); +__device__ __host__ inline bool operator==(const doubledouble &a, const doubledouble &b) { + return (a.head() == b.head() && a.tail() == b.tail()); } -__device__ inline doubledouble add_double_to_doubledouble(const double &a, const double &b) { - return doubledouble(add_double_to_dbldbl(a,b)); +__device__ __host__ inline bool operator!=(const doubledouble &a, const doubledouble &b) { + return !(a == b); } -__device__ inline doubledouble mul_double_to_doubledouble(const double &a, const double &b) { - return doubledouble(mul_double_to_dbldbl(a,b)); +__device__ __host__ inline bool operator>(const doubledouble &a, const double &b) { + return a.head() > b || (a.head() == b && a.tail() > 0); }
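A note on the relational operators added above: a double-double stores value = head + tail with |tail| much smaller than ulp(head), so ordering must be lexicographic, comparing heads strictly first and only consulting tails to break exact head ties (this is why operator>= tests a.head() > b.head() rather than >=). A minimal host-only sketch, not QUDA code, with an illustrative two_sum helper:

```cpp
// Why double-double comparison must fall back to the tail when heads are equal.
#include <cstdio>

struct dd { double hi, lo; }; // hi = head, lo = tail

// Error-free transformation (Knuth two-sum): hi + lo == a + b exactly.
dd two_sum(double a, double b)
{
  double s = a + b;
  double bp = s - a;
  double e = (a - (s - bp)) + (b - bp);
  return {s, e};
}

// Lexicographic ordering, as in the operators above.
bool greater(const dd &a, const dd &b) { return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo); }

int main()
{
  dd a = two_sum(1.0, 1e-30); // both heads round to 1.0; tails differ
  dd b = two_sum(1.0, 2e-30);
  printf("heads equal: %d, a > b: %d, b > a: %d\n", a.hi == b.hi, greater(a, b), greater(b, a));
}
```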
-struct doubledouble2 { - doubledouble x; - doubledouble y; +__device__ __host__ inline doubledouble operator+(const doubledouble &a, const doubledouble &b) { + return doubledouble(add_dbldbl(a.a,b.a)); +} - __device__ __host__ doubledouble2() : x(), y() { } - __device__ __host__ doubledouble2(const doubledouble2 &a) : x(a.x), y(a.y) { } - __device__ __host__ doubledouble2(const double2 &a) : x(a.x), y(a.y) { } - __device__ __host__ doubledouble2(const doubledouble &x, const doubledouble &y) : x(x), y(y) { } +__device__ __host__ inline doubledouble operator-(const doubledouble &a, const doubledouble &b) { + return doubledouble(sub_dbldbl(a.a,b.a)); +} - __device__ doubledouble2& operator+=(const doubledouble2 &a) { - x += a.x; - y += a.y; - return *this; - } +__device__ __host__ inline doubledouble operator*(const doubledouble &a, const doubledouble &b) { + return doubledouble(mul_dbldbl(a.a,b.a)); +} - __device__ __host__ void print() const { printf("vec2: (%16.14e + %16.14e) (%16.14e + %16.14e)\n", x.head(), x.tail(), y.head(), y.tail()); } -}; +__device__ __host__ inline doubledouble operator/(const doubledouble &a, const doubledouble &b) { + return doubledouble(div_dbldbl(a.a,b.a)); +} + +/** + @brief This isn't really an fma for double-double, but it + provides a convenient overload to ensure that, when using native + floating-point types, we consistently use an fma. +*/ +__device__ __host__ inline doubledouble fma(const doubledouble &a, const doubledouble &b, const doubledouble &c) { return a * b + c; } -struct doubledouble3 { +struct doubledouble2 { doubledouble x; doubledouble y; - doubledouble z; - __device__ __host__ doubledouble3() : x(), y() { } - __device__ __host__ doubledouble3(const doubledouble3 &a) : x(a.x), y(a.y), z(a.z) { } - __device__ __host__ doubledouble3(const double3 &a) : x(a.x), y(a.y), z(a.z) { } - __device__ __host__ doubledouble3(const doubledouble &x, const doubledouble &y, const doubledouble &z) : x(x), y(y), z(z) { } + doubledouble2() = default; + constexpr doubledouble2(const doubledouble2 &a) = default; + constexpr doubledouble2(const double2 &a) : x(a.x), y(a.y) { } + constexpr doubledouble2(const doubledouble &x, const doubledouble &y) : x(x), y(y) { } - __device__ doubledouble3& operator+=(const doubledouble3 &a) { + __device__ __host__ doubledouble2& operator+=(const doubledouble2 &a) { x += a.x; y += a.y; - z += a.z; return *this; } - __device__ __host__ void print() const { printf("vec3: (%16.14e + %16.14e) (%16.14e + %16.14e) (%16.14e + %16.14e)\n", x.head(), x.tail(), y.head(), y.tail(), z.head(), z.tail()); } + __device__ __host__ void print() const { printf("vec2: (%16.14e + %16.14e) (%16.14e + %16.14e)\n", x.head(), x.tail(), y.head(), y.tail()); } }; -__device__ doubledouble2 operator+(const doubledouble2 &a, const doubledouble2 &b) +__device__ __host__ inline doubledouble2 operator+(const doubledouble2 &a, const doubledouble2 &b) { return doubledouble2(a.x + b.x, a.y + b.y); } -__device__ doubledouble3 operator+(const doubledouble3 &a, const doubledouble3 &b) -{ return doubledouble3(a.x + b.x, a.y + b.y, a.z + b.z); } +inline std::ostream &operator<<(std::ostream &output, const doubledouble &a) +{ + output << "{" << a.head() << ", " << a.tail() << "}"; + return output; +} diff --git a/include/deflation.h b/include/deflation.h index b9d9132577..1f2aadb543 100644 --- a/include/deflation.h +++ b/include/deflation.h @@ -22,13 +22,13 @@ namespace quda { ColorSpinorField *RV; /** Inverse Ritz values*/ - double *invRitzVals; + real_t *invRitzVals;
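Regarding the fma overload added to dbldbl.h above: its purpose is that generic kernel code can call fma() unconditionally, with native types binding to the hardware fused multiply-add and doubledouble binding to the a * b + c overload. A self-contained sketch of that resolution pattern, using a toy stand-in type (names illustrative, not QUDA code):

```cpp
#include <cmath>
#include <cstdio>

// Toy stand-in for quda::doubledouble, only to show overload resolution;
// real double-double arithmetic is what dbldbl.h implements above.
struct dd { double hi, lo; };
dd operator*(dd a, dd b) { return {a.hi * b.hi, 0.0}; } // placeholder, not error-free
dd operator+(dd a, dd b) { return {a.hi + b.hi, 0.0}; }
dd fma(dd a, dd b, dd c) { return a * b + c; } // mirrors the doubledouble overload

// Generic code can now call fma() for any real type.
template <typename real> real axpy_elem(real a, real x, real y)
{
  using std::fma;      // native float/double bind to the true fused multiply-add
  return fma(a, x, y); // dd binds to the overload above via argument lookup
}

int main()
{
  std::printf("%g\n", axpy_elem(2.0, 3.0, 1.0)); // 7, via std::fma
  dd r = axpy_elem(dd{2.0, 0.0}, dd{3.0, 0.0}, dd{1.0, 0.0});
  std::printf("%g\n", r.hi); // 7, via the dd overload
}
```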
/** The Dirac operator to use for spinor deflation operation */ DiracMatrix &matDeflation; /** Host projection matrix (e.g. eigCG VH A V) */ - Complex *matProj; + complex_t *matProj; /** projection matrix leading dimension */ int ld; @@ -56,8 +56,8 @@ namespace quda { tot_dim = param.np; ld = ((tot_dim+15) / 16) * tot_dim; //allocate deflation resources: - matProj = static_cast<Complex *>(pool_pinned_malloc(ld * tot_dim * sizeof(Complex))); - invRitzVals = new double[tot_dim]; + matProj = static_cast<complex_t *>(pool_pinned_malloc(ld * tot_dim * sizeof(complex_t))); + invRitzVals = new real_t[tot_dim]; //Check that RV is a composite field: if(RV->IsComposite() == false) errorQuda("\nRitz vectors must be contained in a composite field.\n"); @@ -128,7 +128,7 @@ namespace quda { @param tol : keep all eigenvectors with residual norm less than tol @param max_n_ev : keep the lowest max_n_ev eigenvectors (conservative) */ - void reduce(double tol, int max_n_ev); + void reduce(real_t tol, int max_n_ev); /** This applies deflation operation on a given spinor vector(s) diff --git a/include/dirac_quda.h b/include/dirac_quda.h index cfbba175d5..e0369eeaff 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -36,19 +36,19 @@ namespace quda { public: QudaDiracType type; - double kappa; - double mass; - double m5; // used by domain wall only + real_t kappa; + real_t mass; + real_t m5; // used by domain wall only int Ls; // used by domain wall and twisted mass - Complex b_5[QUDA_MAX_DWF_LS]; // used by mobius domain wall only - Complex c_5[QUDA_MAX_DWF_LS]; // used by mobius domain wall only + complex_t b_5[QUDA_MAX_DWF_LS]; // used by mobius domain wall only + complex_t c_5[QUDA_MAX_DWF_LS]; // used by mobius domain wall only // The EOFA parameters. See the description in InvertParam - double eofa_shift; + real_t eofa_shift; int eofa_pm; - double mq1; - double mq2; - double mq3; + real_t mq1; + real_t mq2; + real_t mq3; QudaMatPCType matpcType; QudaDagType dagger; @@ -59,10 +59,10 @@ namespace quda { CloverField *clover; GaugeField *xInvKD; // used for the Kahler-Dirac operator only - double mu; // used by twisted mass only - double mu_factor; // used by multigrid only - double epsilon; //2nd tm parameter (used by twisted mass only) - double tm_rho; // "rho"-type Hasenbusch mass used for twisted clover (like regular rho but + real_t mu; // used by twisted mass only + real_t mu_factor; // used by multigrid only + real_t epsilon; //2nd tm parameter (used by twisted mass only) + real_t tm_rho; // "rho"-type Hasenbusch mass used for twisted clover (like regular rho but // applied like a twisted mass and ignored in the inverse) int commDim[QUDA_MAX_DIM]; // whether to do comms or not @@ -114,24 +114,23 @@ namespace quda { void print() { printfQuda("Printing DslashParam\n"); printfQuda("type = %d\n", type); - printfQuda("kappa = %g\n", kappa); - printfQuda("mass = %g\n", mass); + printfQuda("kappa = %g\n", double(kappa)); + printfQuda("mass = %g\n", double(mass)); printfQuda("laplace3D = %d\n", laplace3D); - printfQuda("m5 = %g\n", m5); + printfQuda("m5 = %g\n", double(m5)); printfQuda("Ls = %d\n", Ls); printfQuda("matpcType = %d\n", matpcType); printfQuda("dagger = %d\n", dagger); - printfQuda("mu = %g\n", mu); - printfQuda("tm_rho = %g\n", tm_rho); - printfQuda("epsilon = %g\n", epsilon); + printfQuda("mu = %g\n", double(mu)); + printfQuda("tm_rho = %g\n", double(tm_rho)); + printfQuda("epsilon = %g\n", double(epsilon)); printfQuda("halo_precision = %d\n", halo_precision); for (int i=0; i<Ls; i++) printfQuda("b_5[%d] = %e %e \t c_5[%d] = %e %e\n", i, static_cast<double>(b_5[i].real()),
static_cast<double>(b_5[i].imag()), i, static_cast<double>(c_5[i].real()), static_cast<double>(c_5[i].imag())); printfQuda("setup_use_mma = %d\n", setup_use_mma); printfQuda("dslash_use_mma = %d\n", dslash_use_mma); - printfQuda("allow_truncation = %d\n", allow_truncation); + printfQuda("allow_truncation = %d\n", allow_truncation); printfQuda("use_mobius_fused_kernel = %s\n", use_mobius_fused_kernel ? "true" : "false"); } }; @@ -169,8 +168,8 @@ namespace quda { protected: GaugeField *gauge; - double kappa; - double mass; + real_t kappa; + real_t mass; int laplace3D; QudaMatPCType matpcType; mutable QudaDagType dagger; // mutable to simplify implementation of Mdag @@ -242,13 +241,13 @@ namespace quda { @brief Xpay version of Dslash */ virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const = 0; + const ColorSpinorField &x, const real_t &k) const = 0; /** @brief Xpay version of Dslash */ virtual void DslashXpay(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, - QudaParity parity, cvector_ref<const ColorSpinorField> &x, double k) const + QudaParity parity, cvector_ref<const ColorSpinorField> &x, real_t k) const { for (auto i = 0u; i < in.size(); i++) DslashXpay(out[i], in[i], parity, x[i], k); } @@ -258,7 +257,7 @@ namespace quda { smearing. */ virtual void SmearOp(ColorSpinorField &, const ColorSpinorField &, - const double, const double, const int, const QudaParity) const + const real_t, const real_t, const int, const QudaParity) const { errorQuda("Not implemented."); } @@ -371,7 +370,7 @@ namespace quda { */ virtual bool hasSpecialMG() const { return false; } - void setMass(double mass){ this->mass = mass;} + void setMass(real_t mass){ this->mass = mass;} // Dirac operator factory /** @@ -382,22 +381,22 @@ namespace quda { /** @brief accessor for Kappa (mass parameter) */ - double Kappa() const { return kappa; } + real_t Kappa() const { return kappa; } /** @brief accessor for Mass (in case of a factor of 2 for staggered) */ - virtual double Mass() const { return mass; } // in case of factor of 2 convention for staggered + virtual real_t Mass() const { return mass; } // in case of factor of 2 convention for staggered /** @brief accessor for twist parameter -- override can return better value */ - virtual double Mu() const { return 0.; } + virtual real_t Mu() const { return 0.; } /** @brief accessor for mu factor for MG -- override can return a better value */ - virtual double MuFactor() const { return 0.; } + virtual real_t MuFactor() const { return 0.; } /** @brief accessor for whether we let MG coarsening drop improvements, for ex long links for small aggregation dimensions @@ -495,7 +494,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, for ex dropping long links for * small aggregate sizes */ - virtual void createCoarseOp(GaugeField &, GaugeField &, const Transfer &, double, double, double, double, bool) const + virtual void createCoarseOp(GaugeField &, GaugeField &, const Transfer &, real_t, real_t, real_t, real_t, bool) const {errorQuda("Not implemented");} QudaPrecision HaloPrecision() const { return halo_precision; } @@ -527,7 +526,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const
ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -556,8 +555,8 @@ namespace quda { * @param kappa Kappa parameter for the coarse operator * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for Wilson operator */ - virtual void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., - double mu = 0., double mu_factor = 0., bool allow_truncation = false) const; + virtual void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass = 0., + real_t mu = 0., real_t mu_factor = 0., bool allow_truncation = false) const; }; // Even-odd preconditioned Wilson @@ -601,7 +600,7 @@ namespace quda { void Clover(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -646,8 +645,8 @@ namespace quda { * @param mass Mass parameter for the coarse operator (hard coded to 0 when CoarseOp is called) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover operator */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass = 0., real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -678,7 +677,7 @@ namespace quda { // out = x + k A_pp^{-1} D_p\bar{p} void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; // Can implement: M as e.g. 
: i) tmp_e = A^{-1}_ee D_eo in_o (Dslash) // ii) out_o = in_o + A_oo^{-1} D_oe tmp_e (AXPY) @@ -708,8 +707,8 @@ namespace quda { * @param mass Mass parameter for the coarse operator (set to zero) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover operator */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass = 0., real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -734,7 +733,7 @@ namespace quda { { protected: - double mu; + real_t mu; public: DiracCloverHasenbuschTwist(const DiracParam &param); @@ -757,15 +756,15 @@ namespace quda { * @param mass Mass parameter for the coarse operator (hard coded to 0 when CoarseOp is called) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass = 0., real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; }; // Even-odd preconditioned clover class DiracCloverHasenbuschTwistPC : public DiracCloverPC { protected: - double mu; + real_t mu; public: DiracCloverHasenbuschTwistPC(const DiracParam &param); @@ -783,11 +782,11 @@ namespace quda { // out = (1 +/- ig5 mu A)x + k A^{-1} D in void DslashXpayTwistClovInv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k, const double &b) const; + const ColorSpinorField &x, const real_t &k, const real_t &b) const; // out = ( 1+/- i g5 mu A) x - D in void DslashXpayTwistNoClovInv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k, const double &b) const; + const ColorSpinorField &x, const real_t &k, const real_t &b) const; // Can implement: M as e.g.
: i) tmp_e = A^{-1}_ee D_eo in_o (Dslash) // ii) out_o = in_o + A_oo^{-1} D_oe tmp_e (AXPY) @@ -811,16 +810,16 @@ namespace quda { * @param mass Mass parameter for the coarse operator (set to zero) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover hasenbusch */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass = 0., real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; }; // Full domain wall class DiracDomainWall : public DiracWilson { protected: - double m5; - double kappa5; + real_t m5; + real_t kappa5; int Ls; // length of the fifth dimension /** @@ -836,7 +835,7 @@ namespace quda { void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -886,8 +885,8 @@ namespace quda { void Dslash4(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; void Dslash5(ColorSpinorField &out, const ColorSpinorField &in) const; void Dslash4Xpay(ColorSpinorField &out, const ColorSpinorField &in, - const QudaParity parity, const ColorSpinorField &x, const double &k) const; - void Dslash5Xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const double &k) const; + const QudaParity parity, const ColorSpinorField &x, const real_t &k) const; + void Dslash5Xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const real_t &k) const; void M(ColorSpinorField &out, const ColorSpinorField &in) const; void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -910,7 +909,7 @@ namespace quda { DiracDomainWall4DPC &operator=(const DiracDomainWall4DPC &dirac); void M5inv(ColorSpinorField &out, const ColorSpinorField &in) const; - void M5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const double &k) const; + void M5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const real_t &k) const; void M(ColorSpinorField &out, const ColorSpinorField &in) const; void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -928,8 +927,8 @@ namespace quda { protected: //Mobius coefficients - Complex b_5[QUDA_MAX_DWF_LS]; - Complex c_5[QUDA_MAX_DWF_LS]; + complex_t b_5[QUDA_MAX_DWF_LS]; + complex_t c_5[QUDA_MAX_DWF_LS]; /** Whether we are using classical Mobius with constant real-valued @@ -938,9 +937,9 @@ namespace quda { */ bool zMobius; - double mobius_kappa_b; - double mobius_kappa_c; - double mobius_kappa; + real_t mobius_kappa_b; + real_t mobius_kappa_c; + real_t mobius_kappa; /** @brief Check whether the input and output are valid 5D fields. 
If zMobius, we require that they @@ -959,11 +958,11 @@ namespace quda { void Dslash5(ColorSpinorField &out, const ColorSpinorField &in) const; void Dslash4Xpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; void Dslash4preXpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, - const double &k) const; + const real_t &k) const; void Dslash5Xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, - const double &k) const; + const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -989,20 +988,20 @@ namespace quda { DiracMobiusPC& operator=(const DiracMobiusPC &dirac); void M5inv(ColorSpinorField &out, const ColorSpinorField &in) const; - void M5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const double &k) const; + void M5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const real_t &k) const; void Dslash4M5invM5pre(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; void Dslash4M5preM5inv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; void Dslash4M5invXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &a) const; + const ColorSpinorField &x, const real_t &a) const; void Dslash4M5preXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &a) const; + const ColorSpinorField &x, const real_t &a) const; void Dslash4XpayM5mob(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &a) const; + const ColorSpinorField &x, const real_t &a) const; void Dslash4M5preXpayM5mob(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &a) const; + const ColorSpinorField &x, const real_t &a) const; void Dslash4M5invXpayM5inv(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &a, ColorSpinorField &y) const; + const ColorSpinorField &x, const real_t &a, ColorSpinorField &y) const; void MdagMLocal(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1023,16 +1022,16 @@ namespace quda { protected: // The EOFA parameters - double m5inv_fac = 0.; - double sherman_morrison_fac = 0.; - double eofa_shift; + real_t m5inv_fac = 0.; + real_t sherman_morrison_fac = 0.; + real_t eofa_shift; int eofa_pm; - double mq1; - double mq2; - double mq3; - double eofa_u[QUDA_MAX_DWF_LS]; - double eofa_x[QUDA_MAX_DWF_LS]; - double eofa_y[QUDA_MAX_DWF_LS]; + real_t mq1; + real_t mq2; + real_t mq3; + real_t eofa_u[QUDA_MAX_DWF_LS]; + real_t eofa_x[QUDA_MAX_DWF_LS]; + real_t eofa_y[QUDA_MAX_DWF_LS]; /** @brief Check whether the input and output are valid 5D fields, and we require that they @@ -1044,7 +1043,7 @@ namespace quda { DiracMobiusEofa(const DiracParam &param); void m5_eofa(ColorSpinorField &out, const ColorSpinorField &in) const;
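A note on the *Xpay naming that recurs throughout these declarations: "xpay" means the result of the stencil is accumulated against an auxiliary field as out = x + a * Op(in), fused into a single kernel pass. A schematic of the composition (illustrative only; plain arrays stand in for ColorSpinorField and the toy operator below is hypothetical):

```cpp
#include <cstddef>
#include <vector>

// Toy stand-in for a linear operator (real code applies a Dslash stencil).
static std::vector<double> apply_op(const std::vector<double> &in)
{
  std::vector<double> out(in.size());
  for (std::size_t i = 0; i < in.size(); i++) out[i] = 2.0 * in[i]; // placeholder
  return out;
}

// The "xpay" composition the *Xpay methods implement in fused form:
// out = x + a * Op(in)
std::vector<double> op_xpay(const std::vector<double> &in, const std::vector<double> &x, double a)
{
  std::vector<double> out = apply_op(in);
  for (std::size_t i = 0; i < out.size(); i++) out[i] = x[i] + a * out[i];
  return out;
}
```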
- void m5_eofa_xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double a = -1.) const; + void m5_eofa_xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, real_t a = -1.) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1065,7 +1064,7 @@ namespace quda { void m5inv_eofa(ColorSpinorField &out, const ColorSpinorField &in) const; void m5inv_eofa_xpay(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, - double a = -1.) const; + real_t a = -1.) const; void M(ColorSpinorField &out, const ColorSpinorField &in) const; void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1086,12 +1085,12 @@ namespace quda { class DiracTwistedMass : public DiracWilson { protected: - mutable double mu; - mutable double epsilon; + mutable real_t mu; + mutable real_t epsilon; void twistedApply(ColorSpinorField &out, const ColorSpinorField &in, const QudaTwistGamma5Type twistType) const; virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; public: DiracTwistedMass(const DiracTwistedMass &dirac); @@ -1112,7 +1111,7 @@ namespace quda { virtual QudaDiracType getDiracType() const { return QUDA_TWISTED_MASS_DIRAC; } - double Mu() const { return mu; } + real_t Mu() const { return mu; } /** * @brief Create the coarse twisted-mass operator @@ -1134,8 +1133,8 @@ namespace quda { * @param mu_factor multiplicative factor for the mu parameter * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted mass */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_trunation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor = 0., bool allow_truncation = false) const; }; // Even-odd preconditioned twisted mass @@ -1152,7 +1151,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; void M(ColorSpinorField &out, const ColorSpinorField &in) const; void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1177,17 +1176,17 @@ namespace quda { * @param mu_factor multiplicative factor for the mu parameter * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted mass */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor = 0., bool allow_truncation = false) const; }; // Full twisted mass with a clover term class DiracTwistedClover : public DiracWilson { protected: - double mu; - double epsilon; - double tm_rho; + real_t mu; + real_t epsilon; + real_t tm_rho; CloverField *clover; void checkParitySpinor(const ColorSpinorField &, const ColorSpinorField &) const; void
twistedCloverApply(ColorSpinorField &out, const ColorSpinorField &in, const QudaTwistGamma5Type twistType, @@ -1203,7 +1202,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1216,7 +1215,7 @@ namespace quda { virtual QudaDiracType getDiracType() const { return QUDA_TWISTED_CLOVER_DIRAC; } - double Mu() const { return mu; } + real_t Mu() const { return mu; } /** * @brief Update the internal gauge, fat gauge, long gauge, clover field pointer as appropriate. @@ -1253,8 +1252,8 @@ namespace quda { * @param mu_factor multiplicative factor for the mu parameter * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted clover */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor = 0., bool allow_truncation = false) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1281,19 +1280,19 @@ namespace quda { void TwistCloverInv(ColorSpinorField &out, const ColorSpinorField &in, const int parity) const; /** - @brief Convenience wrapper for single/doublet + @brief Convenience wrapper for single/doublet */ void WilsonDslash(ColorSpinorField &out, const ColorSpinorField &in, QudaParity parity) const; /** - @brief Convenience wrapper for single/doublet + @brief Convenience wrapper for single/doublet */ void WilsonDslashXpay(ColorSpinorField &out, const ColorSpinorField &in, QudaParity parity, - const ColorSpinorField &x, double k) const; + const ColorSpinorField &x, real_t k) const; virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; void M(ColorSpinorField &out, const ColorSpinorField &in) const; void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1320,8 +1319,8 @@ namespace quda { * @param mu_factor multiplicative factor for the mu parameter * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted clover */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor = 0., bool allow_truncation = false) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1348,7 +1347,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t
&k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1388,8 +1387,8 @@ namespace quda { * @param mu_factor Mu scaling factor for the coarse operator (ignored for staggered) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for staggered */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; /** * @brief Create two-link staggered quark smearing operator @@ -1401,7 +1400,7 @@ namespace quda { * @param[in] t0 time-slice index * @param[in] parity Parity flag */ - void SmearOp(ColorSpinorField &out, const ColorSpinorField &in, const double a, const double b, const int t0, const QudaParity parity) const; + void SmearOp(ColorSpinorField &out, const ColorSpinorField &in, const real_t a, const real_t b, const int t0, const QudaParity parity) const; }; // Even-odd preconditioned staggered @@ -1449,8 +1448,8 @@ namespace quda { * @param mu_factor Mu scaling factor for the coarse operator (ignored for staggered) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for staggered */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; }; // Kahler-Dirac preconditioned staggered @@ -1472,7 +1471,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1522,8 +1521,8 @@ namespace quda { * @param mu_factor Mu scaling factor for the coarse operator (ignored for staggered) * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for staggered */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu = 0., + real_t mu_factor = 0., bool allow_truncation = false) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1552,7 +1551,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1616,8 
+1615,8 @@ namespace quda { * @param mu_factor Mu scaling factor for the coarse operator (ignored for staggered) * @param allow_truncation [in] whether or not we let coarsening drop improvements, dropping long links here */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor, bool allow_truncation) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor, bool allow_truncation) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1638,7 +1637,7 @@ namespace quda { * @param[in] t0 time-slice index * @param[in] parity Parity flag */ - void SmearOp(ColorSpinorField &out, const ColorSpinorField &in, const double a, const double b, const int t0, const QudaParity parity) const; + void SmearOp(ColorSpinorField &out, const ColorSpinorField &in, const real_t a, const real_t b, const int t0, const QudaParity parity) const; }; // Even-odd preconditioned staggered @@ -1686,8 +1685,8 @@ namespace quda { * @param mu_factor Mu scaling factor for the coarse operator (ignored for staggered) * @param allow_truncation [in] whether or not we let coarsening drop improvements, dropping long links here */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor, bool allow_truncation) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor, bool allow_truncation) const; }; // Kahler-Dirac preconditioned staggered @@ -1708,7 +1707,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const; + const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -1758,8 +1757,8 @@ namespace quda { * @param mu_factor Mu scaling factor for the coarse operator (ignored for staggered) * @param allow_truncation [in] whether or not we let coarsening drop improvements, dropping long for asqtad */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor, bool allow_truncation) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor, bool allow_truncation) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1780,9 +1779,9 @@ namespace quda { class DiracCoarse : public Dirac { protected: - double mass; - double mu; - double mu_factor; + real_t mass; + real_t mu; + real_t mu_factor; const Transfer *transfer; /** restrictor / prolongator defined here */ const Dirac *dirac; /** Parent Dirac operator */ const bool need_bidirectional; /** Whether or not to force a bi-directional build */ @@ -1839,9 +1838,9 @@ namespace quda { void createYhat(bool gpu = true) const; public: - double Mass() const { return mass; } - double Mu() const { return mu; } - double MuFactor() const { return mu_factor; } + real_t Mass() const { return mass; } + real_t Mu() const { return mu; } + real_t MuFactor() const { return mu_factor; } bool AllowTruncation() const { return allow_truncation; } /** @@ -1915,10 +1914,10 @@ 
+1914,10 @@ namespace quda { @param[in] k scalar multiplier */ virtual void DslashXpay(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, - QudaParity parity, cvector_ref<const ColorSpinorField> &x, double k) const; + QudaParity parity, cvector_ref<const ColorSpinorField> &x, real_t k) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const + const ColorSpinorField &x, const real_t &k) const { DslashXpay(cvector_ref<ColorSpinorField> {out}, cvector_ref<const ColorSpinorField> {in}, parity, cvector_ref<const ColorSpinorField> {x}, k); @@ -1973,8 +1972,8 @@ namespace quda { * @param mu_factor multiplicative factor for the mu parameter * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for coarse op */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor = 0., bool allow_truncation = false) const; /** * @brief Create the preconditioned coarse operator @@ -2062,10 +2061,10 @@ namespace quda { @param[in] k scalar multiplier */ void DslashXpay(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, QudaParity parity, - cvector_ref<const ColorSpinorField> &x, double k) const; + cvector_ref<const ColorSpinorField> &x, real_t k) const; void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity, - const ColorSpinorField &x, const double &k) const + const ColorSpinorField &x, const real_t &k) const { DslashXpay(cvector_ref<ColorSpinorField> {out}, cvector_ref<const ColorSpinorField> {in}, parity, cvector_ref<const ColorSpinorField> {x}, k); @@ -2116,8 +2115,8 @@ namespace quda { * @param mu_factor multiplicative factor for the mu parameter * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for coarse op */ - void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, real_t kappa, real_t mass, real_t mu, + real_t mu_factor = 0., bool allow_truncation = false) const; /** @brief If managed memory and prefetch is enabled, prefetch @@ -2146,7 +2145,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, - const QudaParity parity, const ColorSpinorField &x, const double &k) const; + const QudaParity parity, const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -2204,7 +2203,7 @@ namespace quda { virtual void Dslash(ColorSpinorField &out, const ColorSpinorField &in, const QudaParity parity) const; virtual void DslashXpay(ColorSpinorField &out, const ColorSpinorField &in, - const QudaParity parity, const ColorSpinorField &x, const double &k) const; + const QudaParity parity, const ColorSpinorField &x, const real_t &k) const; virtual void M(ColorSpinorField &out, const ColorSpinorField &in) const; virtual void MdagM(ColorSpinorField &out, const ColorSpinorField &in) const; @@ -2301,7 +2300,7 @@ namespace quda { const Dirac *Expose() const { return dirac; } //!
Shift term added onto operator (M/M^dag M/M M^dag + shift) - double shift; + real_t shift; }; class DiracM : public DiracMatrix diff --git a/include/dslash_quda.h b/include/dslash_quda.h index ea12741be0..f78c56a9b6 100644 --- a/include/dslash_quda.h +++ b/include/dslash_quda.h @@ -81,7 +81,7 @@ namespace quda @param[in] comm_override Override for which dimensions are partitioned @param[in] profile The TimeProfile used for profiling the dslash */ - void ApplyWilson(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double kappa, + void ApplyWilson(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -107,7 +107,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyWilsonClover(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, - double kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); + real_t kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @brief Driver for applying the Wilson-clover stencil @@ -133,7 +133,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyWilsonCloverHasenbuschTwist(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &A, double kappa, double mu, const ColorSpinorField &x, + const CloverField &A, real_t kappa, real_t mu, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -174,7 +174,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyWilsonCloverPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &A, double kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, + const CloverField &A, real_t kappa, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -223,7 +223,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyWilsonCloverHasenbuschTwistPCClovInv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &A, double kappa, double mu, + const CloverField &A, real_t kappa, real_t mu, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); @@ -251,12 +251,12 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyWilsonCloverHasenbuschTwistPCNoClovInv(ColorSpinorField &out, const ColorSpinorField &in, - const GaugeField &U, const CloverField &A, double kappa, double mu, + const GaugeField &U, const CloverField &A, real_t kappa, real_t mu, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); // old - void ApplyTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, + void ApplyTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); @@ -296,8 +296,8 @@ namespace quda @param[in] comm_override Override for which dimensions are partitioned @param[in] profile The TimeProfile used for 
profiling the dslash */ - void ApplyTwistedMassPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric, + void ApplyTwistedMassPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric, const int *comm_override, TimeProfile &profile); /** @@ -327,8 +327,8 @@ namespace quda @param[in] comm_override Override for which dimensions are partitioned @param[in] profile The TimeProfile used for profiling the dslash */ - void ApplyNdegTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, - double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); + void ApplyNdegTwistedMass(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t b, + real_t c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @brief Driver for applying the preconditioned non-degenerate @@ -377,7 +377,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyNdegTwistedMassPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - double a, double b, double c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric, + real_t a, real_t b, real_t c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, bool asymmetric, const int *comm_override, TimeProfile &profile); /** @@ -404,7 +404,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyTwistedClover(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &C, - double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, + real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -445,7 +445,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyTwistedCloverPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &C, double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, + const CloverField &C, real_t a, real_t b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -477,7 +477,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyNdegTwistedClover(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &C, double a, double b, double c, const ColorSpinorField &x, int parity, + const CloverField &C, real_t a, real_t b, real_t c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -519,7 +519,7 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ void ApplyNdegTwistedCloverPreconditioned(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &C, double a, double b, double c, bool xpay, + const CloverField &C, real_t a, real_t b, real_t c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); @@ -547,7 +547,7 @@ 
namespace quda @param[in] comm_override Override for which dimensions are partitioned @param[in] profile The TimeProfile used for profiling the dslash */ - void ApplyDomainWall5D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_f, + void ApplyDomainWall5D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t m_f, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -576,43 +576,43 @@ namespace quda @param[in] profile The TimeProfile used for profiling the dslash */ - void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, - const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, int parity, bool dagger, + void ApplyDomainWall4D(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t m_5, + const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); - void ApplyDomainWall4DM5inv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5inv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); - void ApplyDomainWall4DM5pre(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5pre(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); - void ApplyDomainWall4DM5invM5pre(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5invM5pre(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); - void ApplyDomainWall4DM5preM5inv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5preM5inv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); - void ApplyDomainWall4DM5invM5inv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double 
a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5invM5inv(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); - void ApplyDomainWall4DM5mob(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5mob(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); - void ApplyDomainWall4DM5preM5mob(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, - double m_5, const Complex *b_5, const Complex *c_5, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f, + void ApplyDomainWall4DM5preM5mob(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, + real_t m_5, const complex_t *b_5, const complex_t *c_5, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f, TimeProfile &profile); /** @brief Apply either the domain-wall / mobius Dslash5 operator or @@ -629,16 +629,16 @@ namespace quda @param[in] dagger Whether this is for the dagger operator @param[in] type Type of dslash we are applying */ - void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, - double m_5, const Complex *b_5, const Complex *c_5, double a, bool dagger, Dslash5Type type); + void ApplyDslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, real_t m_f, + real_t m_5, const complex_t *b_5, const complex_t *c_5, real_t a, bool dagger, Dslash5Type type); // The EOFA stuff namespace mobius_eofa { - void apply_dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, - double m_5, const Complex *b_5, const Complex *c_5, double a, int eofa_pm, double inv, - double kappa, const double *eofa_u, const double *eofa_x, const double *eofa_y, - double sherman_morrison, bool dagger, Dslash5Type type); + void apply_dslash5(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, real_t m_f, + real_t m_5, const complex_t *b_5, const complex_t *c_5, real_t a, int eofa_pm, real_t inv, + real_t kappa, const real_t *eofa_u, const real_t *eofa_x, const real_t *eofa_y, + real_t sherman_morrison, bool dagger, Dslash5Type type); } /** @@ -660,7 +660,7 @@ namespace quda @param[in] b Scale factor applied to aux field @param[in] x Vector field we accumulate onto to */ - void ApplyLaplace(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, double a, double b, + void ApplyLaplace(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -707,7 +707,7 @@ namespace quda @param[in] 
dagger Whether we are applying the dagger or not @param[in] improved whether to apply the standard-staggered (false) or asqtad (true) operator */ - void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, + void ApplyStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -723,7 +723,7 @@ namespace quda @param[in] improved whether to apply the standard-staggered (false) or asqtad (true) operator */ void ApplyImprovedStaggered(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const GaugeField &L, double a, const ColorSpinorField &x, int parity, bool dagger, + const GaugeField &L, real_t a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override, TimeProfile &profile); /** @@ -763,8 +763,8 @@ namespace quda @param[in] dagger Whether we are applying the dagger or not @param[in] twist The type of kernel we are doing */ - void ApplyTwistGamma(ColorSpinorField &out, const ColorSpinorField &in, int d, double kappa, double mu, - double epsilon, int dagger, QudaTwistGamma5Type type); + void ApplyTwistGamma(ColorSpinorField &out, const ColorSpinorField &in, int d, real_t kappa, real_t mu, + real_t epsilon, int dagger, QudaTwistGamma5Type type); /** @brief Apply twisted clover-matrix field to a color-spinor field @@ -781,7 +781,7 @@ namespace quda else if (twist == QUDA_TWIST_GAMMA5_INVERSE) apply (Clover + i*a*gamma_5)/(Clover^2 + a^2) to the input spinor */ void ApplyTwistClover(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover, - double kappa, double mu, double epsilon, int parity, int dagger, QudaTwistGamma5Type twist); + real_t kappa, real_t mu, real_t epsilon, int parity, int dagger, QudaTwistGamma5Type twist); /** @brief Dslash face packing routine @@ -798,7 +798,7 @@ namespace quda @param[in] stream Which stream are we executing in */ void PackGhost(void *ghost[2 * QUDA_MAX_DIM], const ColorSpinorField &field, MemoryLocation location, int nFace, - bool dagger, int parity, bool spin_project, double a, double b, double c, int shmem, + bool dagger, int parity, bool spin_project, real_t a, real_t b, real_t c, int shmem, const qudaStream_t &stream); /** diff --git a/include/eigensolve_quda.h b/include/eigensolve_quda.h index 06472f069b..f615aa8b11 100644 --- a/include/eigensolve_quda.h +++ b/include/eigensolve_quda.h @@ -28,7 +28,7 @@ namespace quda int n_kr; /** Size of Krylov space after extension */ int n_conv; /** Number of converged eigenvalues requested */ int n_ev_deflate; /** Number of converged eigenvalues to use in deflation */ - double tol; /** Tolerance on eigenvalues */ + real_t tol; /** Tolerance on eigenvalues */ bool reverse; /** True if using polynomial acceleration */ char spectrum[3]; /** Part of the spectrum to be computed */ bool compute_svd; /** Compute the SVD if requested **/ @@ -51,7 +51,7 @@ namespace quda int num_locked; int num_keep; - std::vector residua; + std::vector residua; // Device side vector workspace std::vector r; @@ -82,7 +82,7 @@ namespace quda @param kSpace The converged eigenvectors @param evals The converged eigenvalues */ - virtual void operator()(std::vector &kSpace, std::vector &evals) = 0; + virtual void operator()(std::vector &kSpace, std::vector &evals) = 0; /** @brief Creates the eigensolver using the parameters given and the matrix. 
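These eigensolver interfaces, like the dslash interfaces above, assume project-wide real_t and complex_t aliases defined elsewhere in the tree. A minimal sketch of what such an alias pair could look like (the placement shown here is an assumption, not part of this diff):

    #include <complex>

    namespace quda
    {
      // Hypothetical sketch only: one central point of definition for the host
      // scalar type, so that every host-facing signature (ax, axpbyz,
      // ApplyDomainWall4D, EigenSolver::operator(), ...) can be retargeted to a
      // different precision by changing a single alias.
      using real_t = double;
      using complex_t = std::complex<real_t>;
    }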
@@ -110,7 +110,7 @@ namespace quda @param[in] kSpace The Krylov space vectors @param[in] evals The eigenvalue array */ - void prepareKrylovSpace(std::vector &kSpace, std::vector &evals); + void prepareKrylovSpace(std::vector &kSpace, std::vector &evals); /** @brief Set the epsilon parameter @@ -135,7 +135,7 @@ namespace quda @param[in] kSpace The Krylov space vectors @param[in] evals The eigenvalue array */ - void cleanUpEigensolver(std::vector &kSpace, std::vector &evals); + void cleanUpEigensolver(std::vector &kSpace, std::vector &evals); /** @brief Promoted the specified matVec operation: @@ -151,7 +151,7 @@ namespace quda @param[in] out Output spinor @param[in] in Input spinor */ - double estimateChebyOpMax(ColorSpinorField &out, ColorSpinorField &in); + real_t estimateChebyOpMax(ColorSpinorField &out, ColorSpinorField &in); /** @brief Orthogonalise input vectors r against @@ -240,7 +240,7 @@ namespace quda @param[in] accumulate Whether to preserve the sol vector content prior to accumulating */ void deflate(cvector_ref &sol, cvector_ref &src, - cvector_ref &evecs, const std::vector &evals, + cvector_ref &evecs, const std::vector &evals, bool accumulate = false) const; /** @@ -253,7 +253,7 @@ namespace quda @param[in] accumulate Whether to preserve the sol vector content prior to accumulating */ void deflateSVD(cvector_ref &sol, cvector_ref &vec, - cvector_ref &evecs, const std::vector &evals, + cvector_ref &evecs, const std::vector &evals, bool accumulate = false) const; /** @@ -261,7 +261,7 @@ namespace quda @param[in] evecs Computed eigenvectors of NormOp @param[in] evals Computed eigenvalues of NormOp */ - void computeSVD(std::vector &evecs, std::vector &evals); + void computeSVD(std::vector &evecs, std::vector &evals); /** @brief Compute eigenvalues and their residua @@ -270,8 +270,7 @@ namespace quda @param[in] evals The eigenvalues @param[in] size The number of eigenvalues to compute */ - void computeEvals(std::vector &evecs, std::vector &evals, - int size); + void computeEvals(std::vector &evecs, std::vector &evals, int size); /** @brief Compute eigenvalues and their residua. This variant computes the number of converged eigenvalues.
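The sortArrays overloads that follow all realize the same permutation idiom for different key/value type pairs. A self-contained sketch of that idiom, with hypothetical names and an ascending-modulus ordering chosen purely for illustration:

    #include <algorithm>
    #include <cmath>
    #include <complex>
    #include <numeric>
    #include <vector>

    // Sketch only: sort the first n elements of x and permute y in tandem,
    // mirroring the role of the sortArrays overloads declared below.
    template <class X, class Y> void sort_in_tandem(int n, std::vector<X> &x, std::vector<Y> &y)
    {
      std::vector<int> idx(n);
      std::iota(idx.begin(), idx.end(), 0); // identity permutation
      std::sort(idx.begin(), idx.end(),
                [&](int a, int b) { return std::abs(x[a]) < std::abs(x[b]); });
      std::vector<X> xs(n);
      std::vector<Y> ys(n);
      for (int i = 0; i < n; i++) { xs[i] = x[idx[i]]; ys[i] = y[idx[i]]; }
      std::copy(xs.begin(), xs.end(), x.begin()); // write back in sorted order
      std::copy(ys.begin(), ys.end(), y.begin());
    }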
@@ -279,7 +278,7 @@ namespace quda @param[in] evecs The eigenvectors @param[in] evals The eigenvalues */ - void computeEvals(std::vector &evecs, std::vector &evals) + void computeEvals(std::vector &evecs, std::vector &evals) { computeEvals(evecs, evals, n_conv); } @@ -290,7 +289,7 @@ namespace quda @param[in] eig_vecs The eigenvectors to save @param[in] file The filename to save */ - void loadFromFile(std::vector &eig_vecs, std::vector &evals); + void loadFromFile(std::vector &eig_vecs, std::vector &evals); /** @brief Sort the first n elements of array x according to spec_type; y comes along for the ride @@ -299,7 +298,7 @@ namespace quda @param[in] x The array to sort @param[in] y An array whose elements will be permuted in tandem with x */ - void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); + void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); /** @brief Sort the first n elements of array x according to spec_type; y comes along for the ride @@ -309,7 +308,7 @@ namespace quda @param[in] x The array to sort @param[in] y An array whose elements will be permuted in tandem with x */ - void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); + void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); /** @brief Sort the first n elements of array x according to spec_type; y comes along for the ride @@ -319,7 +318,7 @@ namespace quda @param[in] x The array to sort @param[in] y An array whose elements will be permuted in tandem with x */ - void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); + void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); /** @brief Sort the first n elements of array x according to spec_type; y comes along for the ride @@ -330,7 +329,7 @@ namespace quda @param[in] x The array to sort @param[in] y An array whose elements will be permuted in tandem with x */ - void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); + void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); /** @brief Sort the first n elements of array x according to spec_type; y comes along for the ride @@ -341,7 +340,7 @@ namespace quda @param[in] x The array to sort @param[in] y An array whose elements will be permuted in tandem with x */ - void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); + void sortArrays(QudaEigSpectrumType spec_type, int n, std::vector &x, std::vector &y); }; /** @@ -365,18 +364,18 @@ namespace quda virtual bool hermitian() { return true; } /** TRLM is only for Hermitian systems */ // Variable size matrix - std::vector ritz_mat; + std::vector ritz_mat; // Tridiagonal/Arrow matrix, fixed size. - std::vector alpha; - std::vector beta; + std::vector alpha; + std::vector beta; /** @brief Compute eigenpairs @param[in] kSpace Krylov vector space @param[in] evals Computed eigenvalues */ - void operator()(std::vector &kSpace, std::vector &evals); + void operator()(std::vector &kSpace, std::vector &evals); /** @brief Lanczos step: extends the Krylov space. @@ -421,14 +420,14 @@ namespace quda virtual bool hermitian() { return true; } /** (BLOCK)TRLM is only for Hermitian systems */ // Variable size matrix - std::vector block_ritz_mat; + std::vector block_ritz_mat; /** Block Tridiagonal/Arrow matrix, fixed size.
*/ - std::vector block_alpha; - std::vector block_beta; + std::vector block_alpha; + std::vector block_beta; /** Temp storage used in blockLanczosStep, fixed size. */ - std::vector jth_block; + std::vector jth_block; /** Size of blocks of data in alpha/beta */ int block_data_length; @@ -438,7 +437,7 @@ namespace quda @param[in] kSpace Krylov vector space @param[in] evals Computed eigenvalues */ - void operator()(std::vector &kSpace, std::vector &evals); + void operator()(std::vector &kSpace, std::vector &evals); /** @brief Block Lanczos step: extends the Krylov space in block steps @@ -474,9 +473,9 @@ namespace quda { public: - std::vector> upperHess; - std::vector> Qmat; - std::vector> Rmat; + std::vector> upperHess; + std::vector> Qmat; + std::vector> Rmat; /** @brief Constructor for Thick Restarted Eigensolver class @@ -496,7 +495,7 @@ namespace quda @param[in] kSpace Krylov vector space @param[in] evals Computed eigenvalues */ - void operator()(std::vector &kSpace, std::vector &evals); + void operator()(std::vector &kSpace, std::vector &evals); /** @brief Arnoldi step: extends the Krylov space by one vector @@ -505,14 +504,14 @@ namespace quda @param[in] beta Norm of residual vector @param[in] j Index of vector being computed */ - void arnoldiStep(std::vector &v, std::vector &r, double &beta, int j); + void arnoldiStep(std::vector &v, std::vector &r, real_t &beta, int j); /** @brief Get the eigendecomposition from the upper Hessenberg matrix via QR - @param[in] evals Complex eigenvalues + @param[in] evals complex_t eigenvalues @param[in] beta Norm of residual (used to compute errors on eigenvalues) */ - void eigensolveFromUpperHess(std::vector &evals, const double beta); + void eigensolveFromUpperHess(std::vector &evals, const real_t beta); /** @brief Rotate the Krylov space @@ -526,14 +525,14 @@ namespace quda @param[in] evals The shifts to apply @param[in] num_shifts The number of shifts to apply */ - void qrShifts(const std::vector evals, const int num_shifts); + void qrShifts(const std::vector evals, const int num_shifts); /** @brief Apply one step of the QR algorithm @param[in] Q The Q matrix @param[in] R The R matrix */ - void qrIteration(std::vector> &Q, std::vector> &R); + void qrIteration(std::vector> &Q, std::vector> &R); /** @brief Reorder the Krylov space and eigenvalues @@ -542,7 +541,7 @@ namespace quda @param[in] spec_type The spectrum type (Largest/Smallest)(Modulus/Imaginary/Real) that determines the sorting condition */ - void reorder(std::vector &kSpace, std::vector &evals, + void reorder(std::vector &kSpace, std::vector &evals, const QudaEigSpectrumType spec_type); }; @@ -561,7 +560,7 @@ namespace quda @param[in] eig_param Parameter structure for all QUDA eigensolvers @param[in,out] profile TimeProfile instance used for profiling */ - void arpack_solve(std::vector &h_evecs, std::vector &h_evals, const DiracMatrix &mat, + void arpack_solve(std::vector &h_evecs, std::vector &h_evals, const DiracMatrix &mat, QudaEigParam *eig_param, TimeProfile &profile); } // namespace quda diff --git a/include/float_vector.h b/include/float_vector.h index 201d803c95..9bc39d021d 100644 --- a/include/float_vector.h +++ b/include/float_vector.h @@ -12,9 +12,19 @@ #include #include #include +#include "dbldbl.h" +#include "reproducible_floating_accumulator.hpp" namespace quda { +#if defined(QUDA_REDUCTION_ALGORITHM_NAIVE) + using device_reduce_t = reduction_t; +#elif defined(QUDA_REDUCTION_ALGORITHM_KAHAN) + using device_reduce_t = kahan_t; +#elif
defined(QUDA_REDUCTION_ALGORITHM_REPRODUCIBLE) + using device_reduce_t = rfa_t; +#endif + __host__ __device__ inline double2 operator+(const double2 &x, const double2 &y) { return make_double2(x.x + y.x, x.y + y.y); @@ -35,6 +45,15 @@ namespace quda { return make_float2(x.x + y.x, x.y + y.y); } +#ifdef QUDA_REDUCTION_ALGORITHM_REPRODUCIBLE + __host__ __device__ inline device_reduce_t operator+(const device_reduce_t &x, const device_reduce_t &y) + { + device_reduce_t z = x; + z.operator+=(y); + return z; + } +#endif + template __device__ __host__ inline array operator+(const array &a, const array &b) { @@ -44,80 +63,25 @@ namespace quda { return c; } - template constexpr std::enable_if_t, T> zero() { return static_cast(0); } - template constexpr std::enable_if_t>, T> zero() - { - return static_cast(0); - } - - template using specialize = std::enable_if_t, U>; - - template constexpr specialize zero() { return double2 {0.0, 0.0}; } - template constexpr specialize zero() { return double3 {0.0, 0.0, 0.0}; } - template constexpr specialize zero() { return double4 {0.0, 0.0, 0.0, 0.0}; } - - template constexpr specialize zero() { return float2 {0.0f, 0.0f}; } - template constexpr specialize zero() { return float3 {0.0f, 0.0f, 0.0f}; } - template constexpr specialize zero() { return float4 {0.0f, 0.0f, 0.0f, 0.0f}; } - -#ifdef QUAD_SUM - template __device__ __host__ inline specialize zero() { return doubledouble(); } - template __device__ __host__ inline specialize zero() { return doubledouble2(); } - template __device__ __host__ inline specialize zero() { return doubledouble3(); } -#endif - - template __device__ __host__ inline array zero() - { - array v; -#pragma unroll - for (int i = 0; i < n; i++) v[i] = zero(); - return v; - } - - // array of arithmetic types specialization - template - __device__ __host__ inline std::enable_if_t< - std::is_same_v> && std::is_arithmetic_v, T> - zero() - { - return zero(); - } - - // array of array specialization - template - __device__ __host__ inline std::enable_if_t< - std::is_same_v, T::N>>, T> - zero() - { - T v; -#pragma unroll - for (int i = 0; i < v.size(); i++) v[i] = zero(); - return v; - } - - // array of complex specialization - template - __device__ - __host__ inline std::enable_if_t, T::N>>, T> - zero() - { - T v; -#pragma unroll - for (int i = 0; i < v.size(); i++) v[i] = zero(); - return v; - } - /** Container used when we want to track the reference value when computing an infinity norm */ template struct deviation_t { + using value_type = T; + T diff; T ref; + + template constexpr deviation_t &operator=(const deviation_t &other) + { + diff = T(other.diff); + ref = T(other.ref); + return *this; + } }; - template constexpr specialize> zero() { return {0.0, 0.0}; } - template constexpr specialize> zero() { return {0.0f, 0.0f}; } + template struct get_scalar> { using type = typename get_scalar::type; }; template __host__ __device__ inline bool operator>(const deviation_t &a, const deviation_t &b) { @@ -128,6 +92,10 @@ namespace quda { static constexpr std::enable_if_t, T> value() { return std::numeric_limits::lowest(); } }; + template <> struct low { + static constexpr doubledouble value() { return std::numeric_limits::lowest(); } + }; + template struct low> { static inline __host__ __device__ array value() { @@ -146,6 +114,10 @@ namespace quda { static constexpr std::enable_if_t, T> value() { return std::numeric_limits::max(); } }; + template <> struct high { + static constexpr doubledouble value() { return std::numeric_limits::max(); } + }; + 
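The QUDA_REDUCTION_ALGORITHM_KAHAN branch above selects a compensated accumulator for device reductions. For illustration only, a generic compensated (Kahan) accumulator of the kind such a kahan_t could be built on; QUDA's actual type may differ in detail:

    // Illustrative compensated summation; not QUDA's actual kahan_t.
    template <class T> struct kahan_sketch {
      T sum {}; // running total
      T c {};   // running compensation (the low-order bits lost so far)
      kahan_sketch &operator+=(const T &x)
      {
        T y = x - c;       // re-inject the error from the previous step
        T t = sum + y;     // large + small: low-order bits of y are lost here
        c = (t - sum) - y; // algebraically zero; numerically the lost bits
        sum = t;
        return *this;
      }
    };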
template struct RealType { }; template <> struct RealType { diff --git a/include/gauge_field.h b/include/gauge_field.h index bf75bc6bfa..700466df70 100644 --- a/include/gauge_field.h +++ b/include/gauge_field.h @@ -46,8 +46,8 @@ namespace quda { QudaTboundary t_boundary = QUDA_INVALID_T_BOUNDARY; QudaReconstructType reconstruct = QUDA_RECONSTRUCT_NO; - double anisotropy = 1.0; - double tadpole = 1.0; + real_t anisotropy = 1.0; + real_t tadpole = 1.0; GaugeField *field = nullptr; // pointer to a pre-allocated field void *gauge = nullptr; // used when we use a reference to an external field @@ -67,7 +67,7 @@ namespace quda { bool staggeredPhaseApplied = false; /** Imaginary chemical potential */ - double i_mu = 0.0; + real_t i_mu = 0.0; /** Offset into MILC site struct to the desired matrix field (only if gauge_order=MILC_SITE_GAUGE_ORDER) */ size_t site_offset = 0; @@ -188,9 +188,9 @@ namespace quda { QudaLinkType link_type = QUDA_INVALID_LINKS; QudaTboundary t_boundary = QUDA_INVALID_T_BOUNDARY; - double anisotropy = 0.0; - double tadpole = 0.0; - double fat_link_max = 0.0; + real_t anisotropy = 0.0; + real_t tadpole = 0.0; + real_t fat_link_max = 0.0; mutable array ghost = {}; // stores the ghost zone of the gauge field (non-native fields only) @@ -210,7 +210,7 @@ namespace quda { /** Imaginary chemical potential */ - double i_mu = 0.0; + real_t i_mu = 0.0; /** Offset into MILC site struct to the desired matrix field (only if gauge_order=MILC_SITE_GAUGE_ORDER) @@ -351,8 +351,8 @@ namespace quda { int Ncolor() const { return nColor; } QudaReconstructType Reconstruct() const { return reconstruct; } QudaGaugeFieldOrder Order() const { return order; } - double Anisotropy() const { return anisotropy; } - double Tadpole() const { return tadpole; } + real_t Anisotropy() const { return anisotropy; } + real_t Tadpole() const { return tadpole; } QudaTboundary TBoundary() const { return t_boundary; } QudaLinkType LinkType() const { return link_type; } QudaGaugeFixed GaugeFixed() const { return fixed; } @@ -382,9 +382,9 @@ namespace quda { /** Return the imaginary chemical potential applied to this field */ - double iMu() const { return i_mu; } + real_t iMu() const { return i_mu; } - const double& LinkMax() const { return fat_link_max; } + const real_t& LinkMax() const { return fat_link_max; } int Nface() const { return nFace; } /** @@ -518,28 +518,28 @@ namespace quda { @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) @return L1 norm */ - double norm1(int dim = -1, bool fixed = false) const; + real_t norm1(int dim = -1, bool fixed = false) const; /** @brief Compute the L2 norm squared of the field @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) @return L2 norm squared */ - double norm2(int dim = -1, bool fixed = false) const; + real_t norm2(int dim = -1, bool fixed = false) const; /** @brief Compute the absolute maximum of the field (Linfinity norm) @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) @return Absolute maximum value */ - double abs_max(int dim = -1, bool fixed = false) const; + real_t abs_max(int dim = -1, bool fixed = false) const; /** @brief Compute the absolute minimum of the field @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) @return Absolute minimum value */ - double abs_min(int dim = -1, bool fixed = false) const; + real_t abs_min(int dim = -1, bool fixed = false) const; /** Compute checksum of this gauge field: this uses a
XOR-based checksum method @@ -608,7 +608,7 @@ namespace quda { @param u The gauge field that we want the norm of @return The L1 norm of the gauge field */ - double norm1(const GaugeField &u); + real_t norm1(const GaugeField &u); /** @brief This is a debugging function, where we cast a gauge field @@ -616,14 +616,14 @@ namespace quda { @param u The gauge field that we want the norm of @return The L2 norm squared of the gauge field */ - double norm2(const GaugeField &u); + real_t norm2(const GaugeField &u); /** @brief Scale the gauge field by the scalar a. @param[in] a scalar multiplier @param[in] u The gauge field we want to multiply */ - void ax(const double &a, GaugeField &u); + void ax(const real_t &a, GaugeField &u); /** This function is used for extracting the gauge ghost zone from a diff --git a/include/gauge_field_order.h b/include/gauge_field_order.h index 17eb70b42c..8454f60ba3 100644 --- a/include/gauge_field_order.h +++ b/include/gauge_field_order.h @@ -377,7 +377,7 @@ namespace quda { { for (int d = 0; d < U.Geometry(); d++) u[d] = gauge_ ? static_cast **>(gauge_)[d] : U.data *>(d); - resetScale(U.Scale()); + resetScale(double(U.Scale())); } void resetScale(Float max) @@ -418,7 +418,7 @@ namespace quda { in transform_reduce */ template - __host__ double transform_reduce(QudaFieldLocation location, int dim, helper h) const + auto transform_reduce(QudaFieldLocation location, int dim, helper h) const { if (dim >= geometry) errorQuda("Request dimension %d exceeds dimensionality of the field %d", dim, geometry); int lower = (dim == -1) ? 0 : dim; @@ -457,7 +457,7 @@ namespace quda { ghostOffset[d+4] = U.Nface()*U.SurfaceCB(d)*U.Ncolor()*U.Ncolor(); } - resetScale(U.Scale()); + resetScale(double(U.Scale())); } void resetScale(Float max) @@ -492,7 +492,7 @@ namespace quda { scale(static_cast(1.0)), scale_inv(static_cast(1.0)) { - resetScale(U.Scale()); + resetScale(double(U.Scale())); } void resetScale(Float max) @@ -537,12 +537,11 @@ namespace quda { in transform_reduce */ template - __host__ double transform_reduce(QudaFieldLocation location, int dim, helper h) const + auto transform_reduce(QudaFieldLocation location, int dim, helper h) const { if (dim >= geometry) errorQuda("Request dimension %d exceeds dimensionality of the field %d", dim, geometry); auto count = (dim == -1 ? geometry : 1) * volumeCB * nColor * nColor; // items per parity - auto init = reducer::init(); - std::vector result = {init, init}; + std::vector result = {reducer::init(), reducer::init()}; std::vector v = {u + 0 * volumeCB * geometry * nColor * nColor, u + 1 * volumeCB * geometry * nColor * nColor}; if (dim == -1) { @@ -578,7 +577,7 @@ namespace quda { ghostOffset[d+4] = U.Nface()*U.SurfaceCB(d)*U.Ncolor()*U.Ncolor(); } - resetScale(U.Scale()); + resetScale(double(U.Scale())); } void resetScale(Float max) @@ -631,7 +630,7 @@ namespace quda { scale(static_cast(1.0)), scale_inv(static_cast(1.0)) { - resetScale(U.Scale()); + resetScale(double(U.Scale())); } void resetScale(Float max) @@ -673,13 +672,12 @@ namespace quda { in transform_reduce */ template - __host__ double transform_reduce(QudaFieldLocation location, int dim, helper h) const + auto transform_reduce(QudaFieldLocation location, int dim, helper h) const { if (dim >= geometry) errorQuda("Requested dimension %d exceeds dimensionality of the field %d", dim, geometry); auto start = (dim == -1) ? 0 : dim; auto count = (dim == -1 ? 
geometry : 1) * stride * nColor * nColor; - auto init = reducer::init(); - std::vector result = {init, init}; + std::vector result = {reducer::init(), reducer::init()}; std::vector v = {u + 0 * offset_cb + start * count, u + 1 * offset_cb + start * count}; ::quda::transform_reduce(location, result, v, count, h); return reducer::apply(result[0], result[1]); @@ -708,7 +706,7 @@ namespace quda { ghost[d+4] = !native_ghost && U.Geometry() == QUDA_COARSE_GEOMETRY? static_cast*>(ghost_[d+4]) : nullptr; ghostVolumeCB[d+4] = U.Nface()*U.SurfaceCB(d); } - resetScale(U.Scale()); + resetScale(double(U.Scale())); } void resetScale(Float max) @@ -892,10 +890,10 @@ namespace quda { * @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) * @return L1 norm */ - __host__ double norm1(int dim=-1, bool global=true) const { + __host__ real_t norm1(int dim=-1, bool global=true) const { commGlobalReductionPush(global); - double nrm1 = accessor.template transform_reduce>(location, dim, - abs_(accessor.scale_inv)); + real_t nrm1 = real_t(accessor.template transform_reduce> + (location, dim, abs_(accessor.scale_inv))); commGlobalReductionPop(); return nrm1; } @@ -905,11 +903,11 @@ namespace quda { * @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) * @return L2 norm squared */ - __host__ double norm2(int dim = -1, bool global = true) const + __host__ real_t norm2(int dim = -1, bool global = true) const { commGlobalReductionPush(global); - double nrm2 = accessor.template transform_reduce>( - location, dim, square_(accessor.scale_inv)); + real_t nrm2 = real_t(accessor.template transform_reduce> + (location, dim, square_(accessor.scale_inv))); commGlobalReductionPop(); return nrm2; } @@ -919,11 +917,11 @@ namespace quda { * @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) * @return Linfinity norm */ - __host__ double abs_max(int dim = -1, bool global = true) const + __host__ real_t abs_max(int dim = -1, bool global = true) const { commGlobalReductionPush(global); - double absmax = accessor.template transform_reduce>( - location, dim, abs_max_(accessor.scale_inv)); + real_t absmax = real_t(accessor.template transform_reduce> + (location, dim, abs_max_(accessor.scale_inv))); commGlobalReductionPop(); return absmax; } @@ -933,10 +931,10 @@ namespace quda { * @param[in] dim Which dimension we are taking the norm of (dim=-1 means all dimensions) * @return Minimum norm */ - __host__ double abs_min(int dim = -1, bool global = true) const + __host__ real_t abs_min(int dim = -1, bool global = true) const { commGlobalReductionPush(global); - double absmin = accessor.template transform_reduce>( + real_t absmin = accessor.template transform_reduce>( location, dim, abs_min_(accessor.scale_inv)); commGlobalReductionPop(); return absmin; @@ -1270,7 +1268,7 @@ namespace quda { // scale factor is set when using recon-9 Reconstruct(const GaugeField &u, real scale = 1.0) : - anisotropy(u.Anisotropy() * scale, 1.0 / (u.Anisotropy() * scale)), + anisotropy(Float(u.Anisotropy() * scale), Float(1.0 / (u.Anisotropy() * scale))), tBoundary(static_cast(u.TBoundary()) * scale, 1.0 / (static_cast(u.TBoundary()) * scale)), firstTimeSliceBound(u.VolumeCB()), lastTimeSliceBound((u.X()[3] - 1) * u.X()[0] * u.X()[1] * u.X()[2] / 2), @@ -2201,7 +2199,7 @@ namespace quda { LegacyOrder(u, ghost_), gauge(gauge_ ?
gauge_ : u.data()), volumeCB(u.VolumeCB()), - scale(u.Scale()), + scale(real(u.Scale())), scale_inv(1.0 / scale) { if constexpr (length != 18) errorQuda("Gauge length %d not supported", length); diff --git a/include/gauge_path_helper.cuh b/include/gauge_path_helper.cuh index 0281848609..af268bd69c 100644 --- a/include/gauge_path_helper.cuh +++ b/include/gauge_path_helper.cuh @@ -18,7 +18,7 @@ namespace quda { int *buffer; int count; - paths(std::vector& input_path, std::vector& length_h, std::vector& path_coeff_h, int num_paths, int max_length) : + paths(std::vector& input_path, std::vector& length_h, std::vector& path_coeff_h, int num_paths, int max_length) : num_paths(num_paths), max_length(max_length), count(0) @@ -52,8 +52,9 @@ namespace quda { // length array memcpy(path_h + dim * num_paths * max_length, length_h.data(), num_paths*sizeof(int)); - // path_coeff array - memcpy(path_h + dim * num_paths * max_length + num_paths + pad, path_coeff_h.data(), num_paths*sizeof(double)); + // path_coeff array (copy and convert if needed) + double *path_coeff_ = reinterpret_cast(path_h + dim * num_paths * max_length + num_paths + pad); + for (auto i = 0; i < num_paths; i++) path_coeff_[i] = double(path_coeff_h[i]); qudaMemcpy(buffer, path_h, bytes, qudaMemcpyHostToDevice); host_free(path_h); diff --git a/include/gauge_path_quda.h b/include/gauge_path_quda.h index db398c11a5..7b53af9acd 100644 --- a/include/gauge_path_quda.h +++ b/include/gauge_path_quda.h @@ -14,8 +14,8 @@ namespace quda @param[in] num_paths Number of paths @param[in] max_length Maximum length of each path */ - void gaugeForce(GaugeField &mom, const GaugeField &u, double coeff, std::vector &input_path, - std::vector &length, std::vector &path_coeff, int num_paths, int max_length); + void gaugeForce(GaugeField &mom, const GaugeField &u, real_t coeff, std::vector &input_path, + std::vector &length, std::vector &path_coeff, int num_paths, int max_length); /** @brief Compute the product of gauge-links along the given path @@ -28,8 +28,8 @@ namespace quda @param[in] num_paths Number of paths @param[in] max_length Maximum length of each path */ - void gaugePath(GaugeField &out, const GaugeField &u, double coeff, std::vector &input_path, - std::vector &length, std::vector &path_coeff, int num_paths, int max_length); + void gaugePath(GaugeField &out, const GaugeField &u, real_t coeff, std::vector &input_path, + std::vector &length, std::vector &path_coeff, int num_paths, int max_length); /** @brief Compute the trace of an arbitrary set of gauge loops @@ -42,8 +42,8 @@ namespace quda @param[in] num_paths Number of paths @param[in] path_max_length Maximum length of each path */ - void gaugeLoopTrace(const GaugeField &u, std::vector &loop_traces, double factor, - std::vector &input_path, std::vector &length, std::vector &path_coeff_h, + void gaugeLoopTrace(const GaugeField &u, std::vector &loop_traces, real_t factor, + std::vector &input_path, std::vector &length, std::vector &path_coeff_h, int num_paths, int path_max_length); } // namespace quda diff --git a/include/gauge_tools.h b/include/gauge_tools.h index 9b7d68db37..349d0f6b7e 100644 --- a/include/gauge_tools.h +++ b/include/gauge_tools.h @@ -21,7 +21,7 @@ namespace quda * @param tol Tolerance to which the iterative algorithm works * @param fails Number of link failures (device pointer) */ - void projectSU3(GaugeField &U, double tol, int *fails); + void projectSU3(GaugeField &U, real_t tol, int *fails); /** @brief Compute the plaquette of the gauge field @@ -31,7 +31,7 @@ namespace quda
temporal plaquette) site averages normalized such that each plaquette is in the range [0,1] */ - double3 plaquette(const GaugeField &U); + array plaquette(const GaugeField &U); /** @brief Generate Gaussian distributed su(N) or SU(N) fields. If U @@ -47,7 +47,7 @@ namespace quda @param[in] rngstate random states @param[in] sigma Width of Gaussian distribution */ - void gaugeGauss(GaugeField &U, RNG &rngstate, double epsilon); + void gaugeGauss(GaugeField &U, RNG &rngstate, real_t epsilon); /** @brief Generate Gaussian distributed su(N) or SU(N) fields. If U @@ -63,7 +63,7 @@ namespace quda @param[in] seed The seed used for the RNG @param[in] sigma Width of the Gaussian distribution */ - void gaugeGauss(GaugeField &U, unsigned long long seed, double epsilon); + void gaugeGauss(GaugeField &U, unsigned long long seed, real_t epsilon); /** @brief Generate a random noise gauge field. This variant allows @@ -96,7 +96,7 @@ namespace quda @param[in] dataOr Input gauge field @param[in] alpha smearing parameter */ - void APEStep(GaugeField &dataDs, GaugeField &dataOr, double alpha); + void APEStep(GaugeField &dataDs, GaugeField &dataOr, real_t alpha); /** @brief Apply STOUT smearing to the gauge field @@ -105,7 +105,7 @@ namespace quda @param[in] dataOr Input gauge field @param[in] rho smearing parameter */ - void STOUTStep(GaugeField &dataDs, GaugeField &dataOr, double rho); + void STOUTStep(GaugeField &dataDs, GaugeField &dataOr, real_t rho); /** @brief Apply Over Improved STOUT smearing to the gauge field @@ -115,7 +115,7 @@ namespace quda @param[in] rho smearing parameter @param[in] epsilon smearing parameter */ - void OvrImpSTOUTStep(GaugeField &dataDs, GaugeField &dataOr, double rho, double epsilon); + void OvrImpSTOUTStep(GaugeField &dataDs, GaugeField &dataOr, real_t rho, real_t epsilon); /** @brief Apply Wilson Flow steps W1, W2, Vt to the gauge field. @@ -129,7 +129,7 @@ namespace quda @param[in] epsilon Step size @param[in] smear_type Wilson (1x1) or Symanzik improved (2x1) staples, else error */ - void WFlowStep(GaugeField &out, GaugeField &temp, GaugeField &in, double epsilon, QudaGaugeSmearType smear_type); + void WFlowStep(GaugeField &out, GaugeField &temp, GaugeField &in, real_t epsilon, QudaGaugeSmearType smear_type); /** * @brief Gauge fixing with overrelaxation with support for single and multi GPU. @@ -145,7 +145,7 @@ namespace quda * @param[in] stopWtheta, 0 for MILC criterion and 1 to use the theta value */ void gaugeFixingOVR(GaugeField &data, const int gauge_dir, const int Nsteps, const int verbose_interval, - const double relax_boost, const double tolerance, const int reunit_interval, const int stopWtheta); + const real_t relax_boost, const real_t tolerance, const int reunit_interval, const int stopWtheta); /** * @brief Gauge fixing with Steepest descent method with FFTs with support for single GPU only.
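Several of these utilities now return their results by value rather than filling caller-provided out-parameters (plaquette above; computeQCharge and gaugePolyakovLoop below). An assumed caller-side sketch of the adapted plaquette interface, using the documented (total, spatial, temporal) ordering:

    // Usage sketch (assumed, not part of this diff): adapting a caller to the
    // by-value plaquette() interface, which returns the three site averages
    // as an array<real_t, 3> instead of a double3.
    void report_plaquette(const GaugeField &U)
    {
      auto plaq = plaquette(U);
      printfQuda("plaquette = %e (spatial = %e, temporal = %e)\n",
                 double(plaq[0]), double(plaq[1]), double(plaq[2]));
    }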
@@ -161,7 +161,7 @@ namespace quda * @param[in] stopWtheta, 0 for MILC criterion and 1 to use the theta value */ void gaugeFixingFFT(GaugeField &data, const int gauge_dir, const int Nsteps, const int verbose_interval, - const double alpha, const int autotune, const double tolerance, const int stopWtheta); + const real_t alpha, const int autotune, const real_t tolerance, const int stopWtheta); /** @brief Compute the Fmunu tensor @@ -173,22 +173,22 @@ namespace quda /** @brief Compute the topological charge and field energy @param[out] energy The total, spatial, and temporal field energy - @param[out] qcharge The total topological charge @param[in] Fmunu The Fmunu tensor, usually calculated from a smeared configuration + @return The total topological charge */ - void computeQCharge(double energy[3], double &qcharge, const GaugeField &Fmunu); + real_t computeQCharge(array &energy, const GaugeField &Fmunu); /** @brief Compute the topological charge, field energy and the topological charge density per lattice site @param[out] energy The total, spatial, and temporal field energy - @param[out] qcharge The total topological charge @param[out] qdensity The topological charge at each lattice site @param[in] Fmunu The Fmunu tensor, usually calculated from a smeared configuration + @return The total topological charge */ - void computeQChargeDensity(double energy[3], double &qcharge, void *qdensity, const GaugeField &Fmunu); + real_t computeQChargeDensity(array &energy, void *qdensity, const GaugeField &Fmunu); /** * @brief Compute the trace of the Polyakov loop in a given dimension * @param[in] dir The direction to compute the Polyakov loop in * @param[in] profile TimeProfile instance used for profiling. */ - void gaugePolyakovLoop(double ploop[2], const GaugeField &u, int dir, TimeProfile &profile); + array gaugePolyakovLoop(const GaugeField &u, int dir, TimeProfile &profile); } // namespace quda diff --git a/include/invert_quda.h b/include/invert_quda.h index 35a21ce31c..d7aec4775b 100644 --- a/include/invert_quda.h +++ b/include/invert_quda.h @@ -64,7 +64,7 @@ namespace quda { QudaComputeNullVector compute_null_vector; /**< Reliable update tolerance */ - double delta; + real_t delta; /**< Whether to use alternative reliable updates (CG only at the moment) */ bool use_alternative_reliable; @@ -109,13 +109,13 @@ namespace quda { int pipeline; /**< Solver tolerance in the L2 residual norm */ - double tol; + real_t tol; /**< Solver tolerance in the L2 residual norm */ - double tol_restart; + real_t tol_restart; /**< Solver tolerance in the heavy quark residual norm */ - double tol_hq; + real_t tol_hq; /**< Whether to compute the true residual post solve */ bool compute_true_res; @@ -124,10 +124,10 @@ namespace quda { bool sloppy_converge; /**< Actual L2 residual norm achieved in solver */ - double true_res; + real_t true_res; /**< Actual heavy quark residual norm achieved in solver */ - double true_res_hq; + real_t true_res_hq; /**< Maximum number of iterations in the linear solver */ int maxiter; @@ -166,22 +166,22 @@ namespace quda { int num_offset; /** Offsets for multi-shift solver */ - double offset[QUDA_MAX_MULTI_SHIFT]; + real_t offset[QUDA_MAX_MULTI_SHIFT]; /** Solver tolerance for each offset */ - double tol_offset[QUDA_MAX_MULTI_SHIFT]; + real_t tol_offset[QUDA_MAX_MULTI_SHIFT]; /** Solver tolerance for each shift when refinement is applied using the heavy-quark residual */ - double tol_hq_offset[QUDA_MAX_MULTI_SHIFT]; + real_t
tol_hq_offset[QUDA_MAX_MULTI_SHIFT]; /** Actual L2 residual norm achieved in solver for each offset */ - double true_res_offset[QUDA_MAX_MULTI_SHIFT]; + real_t true_res_offset[QUDA_MAX_MULTI_SHIFT]; /** Iterated L2 residual norm achieved in multi shift solver for each offset */ - double iter_res_offset[QUDA_MAX_MULTI_SHIFT]; + real_t iter_res_offset[QUDA_MAX_MULTI_SHIFT]; /** Actual heavy quark residual norm achieved in solver for each offset */ - double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]; + real_t true_res_hq_offset[QUDA_MAX_MULTI_SHIFT]; /** Number of steps in s-step algorithms */ int Nsteps; @@ -193,31 +193,31 @@ namespace quda { int precondition_cycle; /** Tolerance in the inner solver */ - double tol_precondition; + real_t tol_precondition; /** Maximum number of iterations allowed in the inner solver */ int maxiter_precondition; /** Relaxation parameter used in GCR-DD (default = 1.0) */ - double omega; + real_t omega; /** Basis for CA algorithms */ QudaCABasis ca_basis; /** Minimum eigenvalue for Chebyshev CA basis */ - double ca_lambda_min; + real_t ca_lambda_min; /** Maximum eigenvalue for Chebyshev CA basis */ - double ca_lambda_max; // -1 -> power iter generate + real_t ca_lambda_max; // -1 -> power iter generate /** Basis for CA algorithms in a preconditioner */ QudaCABasis ca_basis_precondition; /** Minimum eigenvalue for Chebyshev CA basis in a preconditioner */ - double ca_lambda_min_precondition; + real_t ca_lambda_min_precondition; /** Maximum eigenvalue for Chebyshev CA basis in a preconditioner */ - double ca_lambda_max_precondition; // -1 -> power iter generate + real_t ca_lambda_max_precondition; // -1 -> power iter generate /** Whether to use additive or multiplicative Schwarz preconditioning */ QudaSchwarzType schwarz_type; @@ -242,8 +242,8 @@ namespace quda { int eigcg_max_restarts; int max_restart_num; - double inc_tol; - double eigenval_tol; + real_t inc_tol; + real_t eigenval_tol; QudaVerbosity verbosity_precondition; //! verbosity to use for preconditioner @@ -463,31 +463,31 @@ namespace quda { @param param the QudaInvertParam to be updated */ void updateInvertParam(QudaInvertParam &param, int offset=-1) { - param.true_res = true_res; - param.true_res_hq = true_res_hq; + param.true_res = double(true_res); + param.true_res_hq = double(true_res_hq); param.iter += iter; comm_allreduce_sum(gflops); param.gflops += gflops; param.secs += secs; if (offset >= 0) { - param.true_res_offset[offset] = true_res_offset[offset]; - param.iter_res_offset[offset] = iter_res_offset[offset]; - param.true_res_hq_offset[offset] = true_res_hq_offset[offset]; + param.true_res_offset[offset] = double(true_res_offset[offset]); + param.iter_res_offset[offset] = double(iter_res_offset[offset]); + param.true_res_hq_offset[offset] = double(true_res_hq_offset[offset]); } else { for (int i=0; i(param.eig_param) = eig_param; } @@ -515,7 +515,7 @@ namespace quda { bool deflate_compute; /** If true, instruct the solver to create a deflation space. */ bool recompute_evals; /** If true, instruct the solver to recompute evals from an existing deflation space. */ std::vector evecs; /** Holds the eigenvectors. */ - std::vector evals; /** Holds the eigenvalues. */ + std::vector evals; /** Holds the eigenvalues. */ bool mixed() { return param.precision != param.precision_sloppy; } @@ -553,12 +553,12 @@ namespace quda { @brief a virtual method that performs the inversion and collects some vectors. The default here is a no-op and should not be called.
*/ - virtual void solve_and_collect(ColorSpinorField &, ColorSpinorField &, cvector_ref &, int, double) + virtual void solve_and_collect(ColorSpinorField &, ColorSpinorField &, cvector_ref &, int, real_t) { errorQuda("NOT implemented."); } - void set_tol(double tol) { param.tol = tol; } + void set_tol(real_t tol) { param.tol = tol; } void set_maxiter(int maxiter) { param.maxiter = maxiter; } const DiracMatrix &M() { return mat; } @@ -636,7 +636,7 @@ namespace quda { @param[in] residual_type The type of residual we want to solve for @return L2 stopping condition */ - static double stopping(double tol, double b2, QudaResidualType residual_type); + static real_t stopping(real_t tol, real_t b2, QudaResidualType residual_type); /** @brief Test for solver convergence @@ -646,7 +646,7 @@ namespace quda { @param[in] hq_tol Solver heavy-quark tolerance @return Whether converged */ - bool convergence(double r2, double hq2, double r2_tol, double hq_tol); + bool convergence(real_t r2, real_t hq2, real_t r2_tol, real_t hq_tol); /** @brief Test for HQ solver convergence -- ignore L2 residual @@ -656,7 +656,7 @@ namespace quda { @param[in] hq_tol Solver heavy-quark tolerance @return Whether converged */ - bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol); + bool convergenceHQ(real_t r2, real_t hq2, real_t r2_tol, real_t hq_tol); /** @brief Test for L2 solver convergence -- ignore HQ residual @@ -665,7 +665,7 @@ namespace quda { @param[in] r2_tol Solver L2 tolerance @param[in] hq_tol Solver heavy-quark tolerance */ - bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol); + bool convergenceL2(real_t r2, real_t hq2, real_t r2_tol, real_t hq_tol); /** @brief Prints out the running statistics of the solver @@ -675,7 +675,7 @@ namespace quda { @param[in] r2 L2 norm squared of the residual @param[in] hq2 Heavy quark residual */ - void PrintStats(const char *name, int k, double r2, double b2, double hq2); + void PrintStats(const char *name, int k, real_t r2, real_t b2, real_t hq2); /** @brief Prints out the summary of the solver convergence @@ -688,7 +688,7 @@ namespace quda { @param[in] r2_tol Solver L2 tolerance @param[in] hq_tol Solver heavy-quark tolerance */ - void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol); + void PrintSummary(const char *name, int k, real_t r2, real_t b2, real_t r2_tol, real_t hq_tol); /** @brief Returns the epsilon tolerance for a given precision, by default returns @@ -767,7 +767,7 @@ namespace quda { @return Norm of final power iteration result */ template - static double performPowerIterations(const DiracMatrix &diracm, const ColorSpinorField &start, + static real_t performPowerIterations(const DiracMatrix &diracm, const ColorSpinorField &start, ColorSpinorField &tempvec1, ColorSpinorField &tempvec2, int niter, int normalize_freq, Args &&...args); @@ -784,8 +784,8 @@ namespace quda { */ template static void computeCAKrylovSpace(const DiracMatrix &diracm, std::vector &Ap, - std::vector &p, int n_krylov, QudaCABasis basis, double m_map, - double b_map, Args &&...args); + std::vector &p, int n_krylov, QudaCABasis basis, real_t m_map, + real_t b_map, Args &&...args); /** * @brief Return flops @@ -823,7 +823,7 @@ namespace quda { * @param p_init Initial-search direction.
* @param r2_old_init [description] */ - void operator()(ColorSpinorField &out, ColorSpinorField &in, ColorSpinorField *p_init, double r2_old_init); + void operator()(ColorSpinorField &out, ColorSpinorField &in, ColorSpinorField *p_init, real_t r2_old_init); void blocksolve(ColorSpinorField &out, ColorSpinorField &in) override; @@ -1055,7 +1055,7 @@ namespace quda { @param collect_tol maxiter tolerance start from which the r vectors are to be collected */ virtual void solve_and_collect(ColorSpinorField &out, ColorSpinorField &in, cvector_ref &v_r, - int collect_miniter, double collect_tol) override; + int collect_miniter, real_t collect_tol) override; virtual bool hermitian() const override { return true; } /** PCG is only Hermitian system */ @@ -1099,10 +1099,10 @@ namespace quda { int pipeline; // pipelining factor for legacyGramSchmidt // Various coefficients and params needed on each iteration. - Complex rho0, rho1, alpha, omega, beta; // Various coefficients for the BiCG part of BiCGstab-L. - std::vector gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length. - std::vector tau; // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length. - std::vector sigma; // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Schmidt. (L+1) length. + complex_t rho0, rho1, alpha, omega, beta; // Various coefficients for the BiCG part of BiCGstab-L. + std::vector gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length. + std::vector tau; // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length. + std::vector sigma; // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Schmidt. (L+1) length. ColorSpinorField r_full; //! Full precision residual. ColorSpinorField y; //! Full precision temporary. @@ -1125,7 +1125,7 @@ namespace quda { /** @brief Internal routine for reliable updates. Made to not conflict with BiCGstab's implementation. */ - int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta); + int reliable(real_t &rNorm, real_t &maxrx, real_t &maxrr, const real_t &r2, const real_t &delta); /** * @brief Internal routine for performing the MR part of BiCGstab-L @@ -1203,9 +1203,9 @@ namespace quda { */ int n_krylov; - std::vector alpha; - std::vector beta; - std::vector gamma; + std::vector alpha; + std::vector beta; + std::vector gamma; /** Solver uses lazy allocation: this flag determines whether we have allocated.
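The reliable(...) member declared above implements the usual delta-based reliable-update test. An illustrative version of that test, not the exact QUDA body (the return encoding here is an assumption, and real_t is assumed double-like for std::sqrt):

    #include <cmath>

    // Sketch: trigger a high-precision update once the iterated residual has
    // dropped by a factor delta relative to the maxima seen since the last update.
    int reliable_sketch(real_t &rNorm, real_t &maxrx, real_t &maxrr, const real_t &r2, const real_t &delta)
    {
      rNorm = std::sqrt(r2);
      if (rNorm > maxrx) maxrx = rNorm;
      if (rNorm > maxrr) maxrr = rNorm;
      bool updateX = rNorm < delta * maxrx; // recompute the solution in full precision
      bool updateR = rNorm < delta * maxrr; // recompute the true residual
      if (updateX) maxrx = rNorm;
      if (updateR) maxrr = rNorm;
      return updateX ? 2 : (updateR ? 1 : 0); // hypothetical encoding
    }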
@@ -1218,13 +1218,13 @@ namespace quda { std::vector p; // GCR direction vectors std::vector Ap; // mat * direction vectors - void computeBeta(std::vector &beta, std::vector &Ap, int i, int N, int k); - void updateAp(std::vector &beta, std::vector &Ap, int begin, int size, int k); - void orthoDir(std::vector &beta, std::vector &Ap, int k, int pipeline); - void backSubs(const std::vector &alpha, const std::vector &beta, const std::vector &gamma, - std::vector &delta, int n); - void updateSolution(ColorSpinorField &x, const std::vector &alpha, const std::vector &beta, - std::vector &gamma, int k, std::vector &p); + void computeBeta(std::vector &beta, std::vector &Ap, int i, int N, int k); + void updateAp(std::vector &beta, std::vector &Ap, int begin, int size, int k); + void orthoDir(std::vector &beta, std::vector &Ap, int k, int pipeline); + void backSubs(const std::vector &alpha, const std::vector &beta, const std::vector &gamma, + std::vector &delta, int n); + void updateSolution(ColorSpinorField &x, const std::vector &alpha, const std::vector &beta, + std::vector &gamma, int k, std::vector &p); /** @brief Initiate the fields needed by the solver @@ -1298,10 +1298,10 @@ namespace quda { bool lambda_init; QudaCABasis basis; - std::vector Q_AQandg; // Fused inner product matrix - std::vector Q_AS; // inner product matrix - std::vector alpha; // QAQ^{-1} g - std::vector beta; // QAQ^{-1} QpolyS + std::vector Q_AQandg; // Fused inner product matrix + std::vector Q_AS; // inner product matrix + std::vector alpha; // QAQ^{-1} g + std::vector beta; // QAQ^{-1} QpolyS ColorSpinorField r; @@ -1335,7 +1335,7 @@ namespace quda { /** @brief Check if it's time for a reliable update */ - int reliable(double &rNorm, double &maxrr, int &rUpdate, const double &r2, const double &delta); + int reliable(real_t &rNorm, real_t &maxrr, int &rUpdate, const real_t &r2, const real_t &delta); public: CACG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, @@ -1438,7 +1438,7 @@ namespace quda { bool lambda_init; // whether or not lambda_max has been initialized QudaCABasis basis; // CA basis - std::vector alpha; // Solution coefficient vectors + std::vector alpha; // Solution coefficient vectors ColorSpinorField r; @@ -1459,7 +1459,7 @@ namespace quda { @param[in] q Search direction vectors with the operator applied @param[in] b Source vector against which we are solving */ - void solve(std::vector &psi, std::vector &q, ColorSpinorField &b); + void solve(std::vector &psi, std::vector &q, ColorSpinorField &b); public: CAGCR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig, @@ -1571,7 +1571,7 @@ namespace quda { } virtual void operator()(std::vector &out, ColorSpinorField &in) = 0; - bool convergence(const std::vector &r2, const std::vector &r2_tol, int n) const; + bool convergence(const std::vector &r2, const std::vector &r2_tol, int n) const; }; /** @@ -1603,7 +1603,7 @@ namespace quda { * @param r2_old_array pointer to last values of r2_old for old shifts. Needs to be large enough to hold r2_old for all shifts. */ void operator()(std::vector &x, ColorSpinorField &b, std::vector &p, - std::vector &r2_old_array); + std::vector &r2_old_array); /** * @brief Run multi-shift and return Krylov-space at the end of the solve in p and r2_old_array.
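backSubs above solves the small upper-triangular least-squares system that GCR accumulates. An illustrative back-substitution in that role, not the exact QUDA body (the flat row-major packing of beta is an assumption):

    #include <complex>
    #include <vector>

    // Sketch: solve the upper-triangular system for the update coefficients
    // delta, given the source projections alpha, the orthogonalisation
    // coefficients beta and the direction norms gamma.
    void back_subs_sketch(const std::vector<complex_t> &alpha, const std::vector<complex_t> &beta,
                          const std::vector<real_t> &gamma, std::vector<complex_t> &delta, int n)
    {
      for (int k = n - 1; k >= 0; k--) {
        delta[k] = alpha[k];
        for (int j = k + 1; j < n; j++) delta[k] -= beta[k * n + j] * delta[j]; // assumed packing
        delta[k] /= gamma[k];
      }
    }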
@@ -1613,7 +1613,7 @@ namespace quda { */ void operator()(std::vector &out, ColorSpinorField &in) { - std::vector r2_old(out.size()); + std::vector r2_old(out.size()); std::vector p; (*this)(out, in, p, r2_old); @@ -1649,7 +1649,7 @@ namespace quda { @param[in] q Search direction vectors with the operator applied @param[in] hermitian Whether the linear system is Hermitian or not */ - void solve(std::vector &psi_, std::vector &p, std::vector &q, + void solve(std::vector &psi_, std::vector &p, std::vector &q, const ColorSpinorField &b, bool hermitian); public: @@ -1712,8 +1712,8 @@ namespace quda { */ void increment(ColorSpinorField &V, int n_ev); - void RestartVT(const double beta, const double rho); - void UpdateVm(ColorSpinorField &res, double beta, double sqrtr2); + void RestartVT(const real_t beta, const real_t rho); + void UpdateVm(ColorSpinorField &res, real_t beta, real_t sqrtr2); // EigCG solver: int eigCGsolve(ColorSpinorField &out, ColorSpinorField &in); // InitCG solver: @@ -1762,7 +1762,7 @@ namespace quda { void operator()(ColorSpinorField &out, ColorSpinorField &in); // //GMRESDR method - void RunDeflatedCycles (ColorSpinorField *out, ColorSpinorField *in, const double tol_threshold); + void RunDeflatedCycles (ColorSpinorField *out, ColorSpinorField *in, const real_t tol_threshold); // int FlexArnoldiProcedure (const int start_idx, const bool do_givens); @@ -1782,7 +1782,7 @@ namespace quda { struct deflation_space : public Object { bool svd; /** Whether this space is for an SVD deflation */ std::vector evecs; /** Container for the eigenvectors */ - std::vector evals; /** The eigenvalues */ + std::vector evals; /** The eigenvalues */ }; /** diff --git a/include/invert_x_update.h b/include/invert_x_update.h index 4e79624871..58ae5510e6 100644 --- a/include/invert_x_update.h +++ b/include/invert_x_update.h @@ -15,7 +15,7 @@ namespace quda int _j; /**< the current index */ int _next; /**< the next index */ std::vector _ps; /**< the container for the p-vectors */ - std::vector _alphas; /**< @param _alphas the alphas */ + std::vector _alphas; /**< @param _alphas the alphas */ XUpdateBatch() = default; @@ -49,7 +49,7 @@ namespace quda */ void accumulate_x(ColorSpinorField &x) { - blas::axpy({_alphas.begin(), _alphas.begin() + _j + 1}, {_ps.begin(), _ps.begin() + _j + 1}, x); + blas::axpy({_alphas.begin(), _alphas.begin() + _j + 1}, {_ps.begin(), _ps.begin() + _j + 1}, x); } /** @@ -80,12 +80,12 @@ namespace quda /** @brief Get the current alpha */ - double &get_current_alpha() { return _alphas[_j]; } + real_t &get_current_alpha() { return _alphas[_j]; } /** @brief Get the current alpha */ - const double &get_current_alpha() const { return _alphas[_j]; } + const real_t &get_current_alpha() const { return _alphas[_j]; } /** @brief increase the counter by one (modulo _Np) diff --git a/include/kernels/blas_core.cuh b/include/kernels/blas_core.cuh index 96f0c2b5f1..c0ee41feca 100644 --- a/include/kernels/blas_core.cuh +++ b/include/kernels/blas_core.cuh @@ -99,7 +99,7 @@ namespace quda static constexpr memory_access<0, 0, 0, 0, 1> write{ }; const real a; const real b; - axpbyz_(const real &a, const real &b, const real &) : a(a), b(b) { ; } + axpbyz_(const real_t &a, const real_t &b, const real_t &) : a(a), b(b) { ; } template __device__ __host__ void operator()(T &x, T &y, T &, T &, T &v) const { #pragma unroll @@ -115,7 +115,7 @@ namespace quda static constexpr memory_access<1, 0> read{ }; static constexpr memory_access<0, 1> write{ }; const real a; - axy_(const real &a, const real &,
const real &) : a(a) { ; } + axy_(const real_t &a, const real_t &, const real_t &) : a(a) { ; } template __device__ __host__ void operator()(T &x, T &y, T &, T &, T &) const { #pragma unroll @@ -182,7 +182,7 @@ namespace quda const real a; const real b; const real c; - axpbypczw_(const real &a, const real &b, const real &c) : a(a), b(b), c(c) { ; } + axpbypczw_(const real_t &a, const real_t &b, const real_t &c) : a(a), b(b), c(c) { ; } template __device__ __host__ void operator()(T &x, T &y, T &z, T &w, T &) const { #pragma unroll @@ -202,7 +202,7 @@ namespace quda const real a; const real b; const real c; - axpyBzpcx_(const real &a, const real &b, const real &c) : a(a), b(b), c(c) { ; } + axpyBzpcx_(const real_t &a, const real_t &b, const real_t &c) : a(a), b(b), c(c) { ; } template __device__ __host__ void operator()(T &x, T &y, T &z, T &, T &) const { #pragma unroll @@ -222,7 +222,7 @@ namespace quda static constexpr memory_access<1, 1> write{ }; const real a; const real b; - axpyZpbx_(const real &a, const real &b, const real &) : a(a), b(b) { ; } + axpyZpbx_(const real_t &a, const real_t &b, const real_t &) : a(a), b(b) { ; } template __device__ __host__ void operator()(T &x, T &y, T &z, T &, T &) const { #pragma unroll @@ -369,19 +369,19 @@ namespace quda static constexpr memory_access<1, 1, 1> read{ }; static constexpr memory_access<1, 1> write{ }; complex a; - double3 *Ar3; + array *Ar3; bool init_; - caxpyxmazMR_(const real &a, const real &, const real &) : - a(a), - Ar3(static_cast(reducer::get_device_buffer())), + caxpyxmazMR_(const real_t &a, const real_t &, const real_t &) : + a(real(a)), + Ar3(static_cast*>(reducer::get_device_buffer())), init_(false) { ; } __device__ __host__ void init() { if (!init_) { - double3 result = *Ar3; - a = a.real() * complex((real)result.x, (real)result.y) * ((real)1.0 / (real)result.z); + auto result = *Ar3; + a = a.real() * complex((real)result[0], (real)result[1]) / (real)result[2]; init_ = true; } } @@ -409,7 +409,7 @@ namespace quda static constexpr memory_access<0, 1, 1, 1> write{ }; const real a; const real b; - tripleCGUpdate_(const real &a, const real &b, const real &) : a(a), b(b) { ; } + tripleCGUpdate_(const real_t &a, const real_t &b, const real_t &) : a(a), b(b) { ; } template __device__ __host__ void operator()(T &x, T &y, T &z, T &w, T &) const { #pragma unroll diff --git a/include/kernels/clover_compute.cuh b/include/kernels/clover_compute.cuh index 67f5560845..897af593fa 100644 --- a/include/kernels/clover_compute.cuh +++ b/include/kernels/clover_compute.cuh @@ -22,7 +22,7 @@ namespace quda { int X[4]; // grid dimensions real coeff; - CloverArg(CloverField &clover, const GaugeField &f, double coeff) : + CloverArg(CloverField &clover, const GaugeField &f, real_t coeff) : kernel_param(dim3(f.VolumeCB(), 2, 1)), clover(clover, 0), f(f), diff --git a/include/kernels/clover_invert.cuh b/include/kernels/clover_invert.cuh index 42292b7a83..c0dfb00946 100644 --- a/include/kernels/clover_invert.cuh +++ b/include/kernels/clover_invert.cuh @@ -7,7 +7,7 @@ namespace quda { - template struct CloverInvertArg : public ReduceArg> { + template struct CloverInvertArg : public ReduceArg> { using store_t = store_t_; using real = typename mapper::type; static constexpr bool twist = twist_; diff --git a/include/kernels/clover_sigma_outer_product.cuh b/include/kernels/clover_sigma_outer_product.cuh index 91d2de8ef5..039392b2ab 100644 --- a/include/kernels/clover_sigma_outer_product.cuh +++ b/include/kernels/clover_sigma_outer_product.cuh @@ -27,15 
+27,15 @@ namespace quda CloverSigmaOprodArg(GaugeField &oprod, const std::vector &inA, const std::vector &inB, - const std::vector> &coeff_) : + const std::vector> &coeff_) : kernel_param(dim3(oprod.VolumeCB(), 2, 6)), oprod(oprod), inA{*inA[0]}, inB{*inB[0]} { for (int i = 0; i < nvector; i++) { - coeff[i][0] = coeff_[i][0]; - coeff[i][1] = coeff_[i][1]; + coeff[i][0] = real(coeff_[i][0]); + coeff[i][1] = real(coeff_[i][1]); } } }; diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index b63bf7f435..df66698cca 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -164,7 +164,7 @@ namespace quda { CalculateYArg(coarseGauge &Y, coarseGauge &X, coarseGaugeAtomic &Y_atomic, coarseGaugeAtomic &X_atomic, fineSpinorUV &UV, fineSpinorAV &AV, const fineGauge &U, const fineGauge &L, const fineGauge &K, const fineSpinorV &V, - const fineClover &C, const fineClover &Cinv, const ColorSpinorField &v, double kappa, double mass, double mu, double mu_factor, + const fineClover &C, const fineClover &Cinv, const ColorSpinorField &v, real_t kappa, real_t mass, real_t mu, real_t mu_factor, const int *x_size_, const int *xc_size_, int spin_bs_, const int *fine_to_coarse, const int *coarse_to_fine, bool bidirectional) : Y(Y), X(X), Y_atomic(Y_atomic), X_atomic(X_atomic), diff --git a/include/kernels/dslash_clover_helper.cuh b/include/kernels/dslash_clover_helper.cuh index 00b61a9b3e..03685b9f62 100644 --- a/include/kernels/dslash_clover_helper.cuh +++ b/include/kernels/dslash_clover_helper.cuh @@ -43,7 +43,7 @@ namespace quda { QudaTwistGamma5Type twist; CloverArg(ColorSpinorField &out, const ColorSpinorField &in, const CloverField &clover, - int parity, real kappa=0.0, real mu=0.0, real epsilon = 0.0, + int parity, real_t kappa=0.0, real_t mu=0.0, real_t epsilon = 0.0, bool dagger = false, QudaTwistGamma5Type twist = QUDA_TWIST_GAMMA5_INVALID) : kernel_param(dim3(in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET ? in.VolumeCB()/2 : in.VolumeCB(), in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET ? 
2 : 1, in.SiteSubset())), @@ -58,20 +58,20 @@ namespace quda { checkLocation(out, in, clover); if (in.TwistFlavor() == QUDA_TWIST_SINGLET) { if (twist == QUDA_TWIST_GAMMA5_DIRECT) { - a = 2.0 * kappa * mu; + a = real(2.0 * kappa * mu); b = 1.0; } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { - a = -2.0 * kappa * mu; + a = -real(2.0 * kappa * mu); b = 1.0 / (1.0 + a*a); } if (dagger) a *= -1.0; } else if (doublet) { if (twist == QUDA_TWIST_GAMMA5_DIRECT){ - a = 2.0 * kappa * mu; - b = -2.0 * kappa * epsilon; + a = real(2.0 * kappa * mu); + b = -real(2.0 * kappa * epsilon); } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { - a = -2.0 * kappa * mu; - b = 2.0 * kappa * epsilon; + a = -real(2.0 * kappa * mu); + b = real(2.0 * kappa * epsilon); } if (dagger) a *= -1.0; } diff --git a/include/kernels/dslash_domain_wall_4d.cuh b/include/kernels/dslash_domain_wall_4d.cuh index 09a09a9f3e..f6f961a9d0 100644 --- a/include/kernels/dslash_domain_wall_4d.cuh +++ b/include/kernels/dslash_domain_wall_4d.cuh @@ -11,16 +11,16 @@ namespace quda int Ls; /** fifth dimension length */ complex a_5[QUDA_MAX_DWF_LS]; /** xpay scale factor for each 4-d subvolume */ - DomainWall4DArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, - const Complex *b_5, const Complex *c_5, bool xpay, const ColorSpinorField &x, int parity, + DomainWall4DArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t m_5, + const complex_t *b_5, const complex_t *c_5, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : - WilsonArg(out, in, U, xpay ? a : 0.0, x, parity, dagger, comm_override), + WilsonArg(out, in, U, xpay ? a : real_t(0.0), x, parity, dagger, comm_override), Ls(in.X(4)) { if (b_5 == nullptr || c_5 == nullptr) - for (int s = 0; s < Ls; s++) a_5[s] = a; // 4-d Shamir + for (int s = 0; s < Ls; s++) a_5[s] = real(a); // 4-d Shamir else - for (int s = 0; s < Ls; s++) a_5[s] = 0.5 * a / (b_5[s] * (m_5 + 4.0) + 1.0); // 4-d Mobius + for (int s = 0; s < Ls; s++) a_5[s] = complex(real_t(0.5) * a / (b_5[s] * (m_5 + real_t(4.0)) + real_t(1.0))); // 4-d Mobius } }; diff --git a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh index dfcd519214..7e3614976c 100644 --- a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh +++ b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh @@ -40,16 +40,16 @@ namespace quda bool fuse_m5inv_m5pre; - DomainWall4DFusedM5Arg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_5, - const Complex *b_5, const Complex *c_5, bool xpay, const ColorSpinorField &x, - ColorSpinorField &y, int parity, bool dagger, const int *comm_override, double m_f) : + DomainWall4DFusedM5Arg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t m_5, + const complex_t *b_5, const complex_t *c_5, bool xpay, const ColorSpinorField &x, + ColorSpinorField &y, int parity, bool dagger, const int *comm_override, real_t m_f) : DomainWall4DArg(out, in, U, a, m_5, b_5, c_5, xpay, x, parity, dagger, comm_override), Dslash5Arg(out, in, x, m_f, m_5, b_5, c_5, a), y(y) { for (int s = 0; s < Ls; s++) { - auto kappa_b_s = 0.5 / (b_5[s] * (m_5 + 4.0) + 1.0); - a_5[s] = a * kappa_b_s * kappa_b_s; + auto kappa_b_s = real_t(0.5) / (b_5[s] * (real_t(m_5) + real_t(4.0)) + real_t(1.0)); + a_5[s] = static_cast>(a * kappa_b_s * kappa_b_s); }; // 4-d Mobius } }; diff --git 
a/include/kernels/dslash_domain_wall_5d.cuh b/include/kernels/dslash_domain_wall_5d.cuh index c9c344df3d..624e2e6360 100644 --- a/include/kernels/dslash_domain_wall_5d.cuh +++ b/include/kernels/dslash_domain_wall_5d.cuh @@ -14,7 +14,7 @@ namespace quda real a; /** xpay scale factor */ real m_f; /** fermion mass parameter */ - DomainWall5DArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double m_f, + DomainWall5DArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t m_f, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, xpay ? a : 0.0, x, parity, dagger, comm_override), Ls(in.X(4)), diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index bab21d4c11..8c7b03a86c 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -110,25 +110,25 @@ namespace quda coeff_5 coeff; // constant buffer used for Mobius coefficients for CPU kernel - void compute_coeff_mobius_pre(const Complex *b_5, const Complex *c_5) + void compute_coeff_mobius_pre(const complex_t *b_5, const complex_t *c_5) { // out = (b + c * D5) * in for (int s = 0; s < Ls; s++) { coeff.beta[s] = b_5[s]; - coeff.alpha[s] = 0.5 * c_5[s]; // 0.5 from gamma matrices + coeff.alpha[s] = 0.5 * complex(c_5[s]); // 0.5 from gamma matrices // xpay - coeff.a[s] = 0.5 / (b_5[s] * (m_5 + 4.0) + 1.0); + coeff.a[s] = 0.5 / (complex(b_5[s]) * (m_5 + 4.0) + 1.0); coeff.a[s] *= coeff.a[s] * static_cast(a); // kappa_b * kappa_b * a } } - void compute_coeff_mobius(const Complex *b_5, const Complex *c_5) + void compute_coeff_mobius(const complex_t *b_5, const complex_t *c_5) { // out = (1 + kappa * D5) * in for (int s = 0; s < Ls; s++) { - coeff.kappa[s] = 0.5 * (c_5[s] * (m_5 + 4.0) - 1.0) / (b_5[s] * (m_5 + 4.0) + 1.0); // 0.5 from gamma matrices + coeff.kappa[s] = 0.5 * (complex(c_5[s]) * (m_5 + 4.0) - 1.0) / (complex(b_5[s]) * (m_5 + 4.0) + 1.0); // 0.5 from gamma matrices // axpy - coeff.a[s] = 0.5 / (b_5[s] * (m_5 + 4.0) + 1.0); + coeff.a[s] = 0.5 / (complex(b_5[s]) * (m_5 + 4.0) + 1.0); coeff.a[s] *= coeff.a[s] * static_cast(a); // kappa_b * kappa_b * a } } @@ -139,33 +139,33 @@ namespace quda inv = 0.5 / (1.0 + std::pow(kappa, (int)Ls) * m_f); } - void compute_coeff_m5inv_mobius(const Complex *b_5, const Complex *c_5) + void compute_coeff_m5inv_mobius(const complex_t *b_5, const complex_t *c_5) { // out = (1 + kappa * D5)^-1 * in = M5inv * in - kappa = -(c_5[0].real() * (4.0 + m_5) - 1.0) / (b_5[0].real() * (4.0 + m_5) + 1.0); // kappa = kappa_b / kappa_c + kappa = -(double(c_5[0].real()) * (4.0 + m_5) - 1.0) / (double(b_5[0].real()) * (4.0 + m_5) + 1.0); // kappa = kappa_b / kappa_c inv = 0.5 / (1.0 + std::pow(kappa, (int)Ls) * m_f); // 0.5 from gamma matrices - a *= pow(0.5 / (b_5[0].real() * (m_5 + 4.0) + 1.0), 2); // kappa_b * kappa_b * a + a *= pow(0.5 / (double(b_5[0].real()) * (m_5 + 4.0) + 1.0), 2); // kappa_b * kappa_b * a } - void compute_coeff_m5inv_zmobius(const Complex *b_5, const Complex *c_5) + void compute_coeff_m5inv_zmobius(const complex_t *b_5, const complex_t *c_5) { // out = (1 + kappa * D5)^-1 * in = M5inv * in // Similar to mobius convention, but variadic across 5th dim complex k = 1.0; for (int s = 0; s < Ls; s++) { - coeff.kappa[s] = -(c_5[s] * (4.0 + m_5) - 1.0) / (b_5[s] * (4.0 + m_5) + 1.0); + coeff.kappa[s] = -(complex(c_5[s]) * (4.0 + m_5) - 1.0) / (complex(b_5[s]) * (4.0 + m_5) + 1.0); k *= 
coeff.kappa[s]; } coeff.inv = static_cast(0.5) / (static_cast(1.0) + k * m_f); for (int s = 0; s < Ls; s++) { // axpy coefficients - coeff.a[s] = 0.5 / (b_5[s] * (m_5 + 4.0) + 1.0); + coeff.a[s] = 0.5 / (complex(b_5[s]) * (m_5 + 4.0) + 1.0); coeff.a[s] *= coeff.a[s] * static_cast(a); } } - Dslash5Arg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, double m_f, double m_5, - const Complex *b_5, const Complex *c_5, double a_) : + Dslash5Arg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, real_t m_f, real_t m_5, + const complex_t *b_5, const complex_t *c_5, real_t a_) : kernel_param(dim3(in.VolumeCB() / in.X(4), in.X(4), in.SiteSubset())), out(out), in(in), diff --git a/include/kernels/dslash_gamma_helper.cuh b/include/kernels/dslash_gamma_helper.cuh index 3b5e27492a..041cba0586 100644 --- a/include/kernels/dslash_gamma_helper.cuh +++ b/include/kernels/dslash_gamma_helper.cuh @@ -27,7 +27,7 @@ namespace quda { real c; // flavor twist GammaArg(ColorSpinorField &out, const ColorSpinorField &in, int d, - real kappa=0.0, real mu=0.0, real epsilon=0.0, + real_t kappa=0.0, real_t mu=0.0, real_t epsilon=0.0, bool dagger=false, QudaTwistGamma5Type twist=QUDA_TWIST_GAMMA5_INVALID) : out(out), in(in), d(d), nParity(in.SiteSubset()), doublet(in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET), @@ -41,24 +41,24 @@ namespace quda { if (in.TwistFlavor() == QUDA_TWIST_SINGLET) { if (twist == QUDA_TWIST_GAMMA5_DIRECT) { - b = 2.0 * kappa * mu; + b = real(2.0 * kappa * mu); a = 1.0; } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { - b = -2.0 * kappa * mu; + b = -real(2.0 * kappa * mu); a = 1.0 / (1.0 + b * b); } c = 0.0; if (dagger) b *= -1.0; } else if (doublet) { if (twist == QUDA_TWIST_GAMMA5_DIRECT) { - b = 2.0 * kappa * mu; - c = -2.0 * kappa * epsilon; + b = real(2.0 * kappa * mu); + c = -real(2.0 * kappa * epsilon); a = 1.0; } else if (twist == QUDA_TWIST_GAMMA5_INVERSE) { - b = -2.0 * kappa * mu; - c = 2.0 * kappa * epsilon; + b = -real(2.0 * kappa * mu); + c = real(2.0 * kappa * epsilon); a = 1.0 / (1.0 + b * b - c * c); - if (a <= 0) errorQuda("Invalid twisted mass parameters (kappa=%e, mu=%e, epsilon=%e)\n", kappa, mu, epsilon); + if (a <= 0) errorQuda("Invalid twisted mass parameters (kappa=%e, mu=%e, epsilon=%e)\n", double(kappa), double(mu), double(epsilon)); } if (dagger) b *= -1.0; } diff --git a/include/kernels/dslash_mdw_fused.cuh b/include/kernels/dslash_mdw_fused.cuh index 67f98b30cc..be3b6e31e7 100644 --- a/include/kernels/dslash_mdw_fused.cuh +++ b/include/kernels/dslash_mdw_fused.cuh @@ -87,7 +87,7 @@ namespace quda { const bool comm[4]; FusedDslashArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, ColorSpinorField &y, - const ColorSpinorField &x, double m_f_, double m_5_, const Complex *b_5, const Complex *c_5, + const ColorSpinorField &x, real_t m_f_, real_t m_5_, const complex_t *b_5, const complex_t *c_5, int parity, int shift_[4], int halo_shift_[4]) : out(out), in(in), @@ -116,8 +116,8 @@ namespace quda { if (b_5[0] != b_5[1] || b_5[0].imag() != 0) { errorQuda("zMobius is NOT supported yet.\n"); } - b = b_5[0].real(); - c = c_5[0].real(); + b = double(b_5[0].real()); + c = double(c_5[0].real()); kappa = -(c * (4. + m_5) - 1.) / (b * (4. + m_5) + 1.); // This is actually -kappa in my(Jiqun Tu) notes. 
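// A minimal sketch, assuming only that real_t converts to double through an
// explicit cast (the locals b5, c5, m5 below are illustrative, not from the
// patch): this is the narrowing convention the constructor above follows for
// host-side coefficient algebra.
//
//   double b5 = double(b_5[0].real());
//   double c5 = double(c_5[0].real());
//   double m5 = double(m_5_);
//   double kappa = -(c5 * (4. + m5) - 1.) / (b5 * (4. + m5) + 1.);
//
// Keeping the algebra in double avoids mixing real_t into expressions whose
// implicit conversions may be unavailable when real_t is not a plain double.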
if (kappa * kappa < 1e-6) { small_kappa = true; } diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index f5e0a5c8ac..fceca25c70 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -52,9 +52,9 @@ namespace quda eofa_coeff coeff; - Dslash5Arg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const double m_f_, - const double m_5_, const Complex */*b_5_*/, const Complex */*c_5_*/, double a_, double inv_, double kappa_, - const double *eofa_u, const double *eofa_x, const double *eofa_y, double sherman_morrison_) : + Dslash5Arg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &x, const real_t m_f_, + const real_t m_5_, const complex_t */*b_5_*/, const complex_t */*c_5_*/, real_t a_, real_t inv_, real_t kappa_, + const real_t *eofa_u, const real_t *eofa_x, const real_t *eofa_y, real_t sherman_morrison_) : kernel_param(dim3(in.VolumeCB() / in.X(4), in.X(4), in.SiteSubset())), out(out), in(in), @@ -76,13 +76,13 @@ namespace quda switch (type) { case Dslash5Type::M5_EOFA: - for (int s = 0; s < Ls; s++) { coeff.u[s] = eofa_u[s]; } + for (int s = 0; s < Ls; s++) { coeff.u[s] = real(eofa_u[s]); } break; case Dslash5Type::M5INV_EOFA: for (int s = 0; s < Ls; s++) { - coeff.u[s] = eofa_u[s]; - coeff.x[s] = eofa_x[s]; - coeff.y[s] = eofa_y[s]; + coeff.u[s] = real(eofa_u[s]); + coeff.x[s] = real(eofa_x[s]); + coeff.y[s] = real(eofa_y[s]); } break; default: errorQuda("Unexpected EOFA Dslash5Type %d", static_cast(type)); diff --git a/include/kernels/dslash_ndeg_twisted_clover.cuh b/include/kernels/dslash_ndeg_twisted_clover.cuh index 108f8c5e84..b052c90076 100644 --- a/include/kernels/dslash_ndeg_twisted_clover.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover.cuh @@ -21,8 +21,8 @@ namespace quda real c; /** this is the flavor twist factor */ NdegTwistedCloverArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &A, double a, double b, - double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : + const CloverField &A, real_t a, real_t b, + real_t c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a, x, parity, dagger, comm_override), A(A, false), a(a), diff --git a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh index bdbff30817..6f970558ce 100644 --- a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh @@ -25,7 +25,7 @@ namespace quda NdegTwistedCloverPreconditionedArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, - double a, double b, double c, bool xpay, + real_t a, real_t b, real_t c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, xpay ? 
1.0 : 0.0, x, parity, dagger, comm_override), diff --git a/include/kernels/dslash_ndeg_twisted_mass.cuh b/include/kernels/dslash_ndeg_twisted_mass.cuh index 2f88d5ee48..fd6ee8b3f6 100644 --- a/include/kernels/dslash_ndeg_twisted_mass.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass.cuh @@ -12,8 +12,8 @@ namespace quda real b; /** this is the chiral twist factor */ real c; /** this is the flavor twist factor */ - NdegTwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, - double c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : + NdegTwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t b, + real_t c, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a, x, parity, dagger, comm_override), a(a), b(dagger ? -b : b), // if dagger flip the chiral twist diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index 98e72eb61a..16d052e895 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -18,8 +18,8 @@ namespace quda real b_inv; /** inverse chiral twist factor - used to allow early xpay inclusion */ real c_inv; /** inverse flavor twist factor - used to allow early xpay inclusion */ - NdegTwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, - double c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : + NdegTwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t b, + real_t c, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, xpay ? 1.0 : 0.0, x, parity, dagger, comm_override), a(a), b(dagger ? -b : b), // if dagger flip the chiral twist diff --git a/include/kernels/dslash_pack.cuh b/include/kernels/dslash_pack.cuh index b4c22a4bb9..dfaf6ad81a 100644 --- a/include/kernels/dslash_pack.cuh +++ b/include/kernels/dslash_pack.cuh @@ -71,8 +71,8 @@ namespace quda #else static constexpr int shmem = 0; #endif - PackArg(void **ghost, const ColorSpinorField &in, int nFace, int parity, int work_items, double a, double b, - double c, unsigned int block, unsigned int grid, + PackArg(void **ghost, const ColorSpinorField &in, int nFace, int parity, int work_items, real_t a, real_t b, + real_t c, unsigned int block, unsigned int grid, #ifdef NVSHMEM_COMMS int shmem_) : #else diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index deb38455f8..a2fae64ef9 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -49,7 +49,7 @@ namespace quda const real dagger_scale; - StaggeredArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, double a, + StaggeredArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const GaugeField &L, real_t a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : DslashArg(out, in, U, x, parity, dagger, a == 0.0 ? false : true, improved_ ? 
3 : 1, spin_project, comm_override), diff --git a/include/kernels/dslash_twisted_clover_preconditioned.cuh b/include/kernels/dslash_twisted_clover_preconditioned.cuh index bc46d106f4..9284c20426 100644 --- a/include/kernels/dslash_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_twisted_clover_preconditioned.cuh @@ -22,7 +22,7 @@ namespace quda real b2; TwistedCloverArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, - double a, double b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, + real_t a, real_t b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, xpay ? 1.0 : 0.0, x, parity, dagger, comm_override), A(A, false), diff --git a/include/kernels/dslash_twisted_mass.cuh b/include/kernels/dslash_twisted_mass.cuh index 85230c729d..b47a0fff43 100644 --- a/include/kernels/dslash_twisted_mass.cuh +++ b/include/kernels/dslash_twisted_mass.cuh @@ -11,7 +11,7 @@ namespace quda real a; /** xpay scale factor */ real b; /** this is the twist factor */ - TwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, + TwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a, x, parity, dagger, comm_override), a(a), diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 496e12276d..bf62192548 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -15,7 +15,7 @@ namespace quda real a_inv; /** inverse scaling factor - used to allow early xpay inclusion */ real b_inv; /** inverse twist factor - used to allow early xpay inclusion */ - TwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, double b, bool xpay, + TwistedMassArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, real_t b, bool xpay, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, xpay ? 1.0 : 0.0, x, parity, dagger, comm_override), a(a), diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index f87e8f9865..7720a21bda 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -36,9 +36,9 @@ namespace quda const G U; /** the gauge field */ const real a; /** xpay scale factor - can be -kappa or -kappa^2 */ - WilsonArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, double a, + WilsonArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, real_t a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : - DslashArg(out, in, U, x, parity, dagger, a != 0.0 ? true : false, 1, spin_project, comm_override), + DslashArg(out, in, U, x, parity, dagger, a != real_t(0.0) ?
true : false, 1, spin_project, comm_override), out(out), in(in), in_pack(in), diff --git a/include/kernels/dslash_wilson_clover.cuh b/include/kernels/dslash_wilson_clover.cuh index ffc8b880e8..09036b0405 100644 --- a/include/kernels/dslash_wilson_clover.cuh +++ b/include/kernels/dslash_wilson_clover.cuh @@ -21,7 +21,7 @@ namespace quda const real b; /** chiral twist factor (twisted-clover only) */ WilsonCloverArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, - double a, double b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : + real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a, x, parity, dagger, comm_override), A(A, false), a(a), diff --git a/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh b/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh index e6923bf707..80946efbd1 100644 --- a/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh +++ b/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh @@ -20,7 +20,7 @@ namespace quda const real b; /** chiral twist factor (twisted-clover only) */ WilsonCloverHasenbuschTwistArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &A, double a, double b, const ColorSpinorField &x, int parity, + const CloverField &A, real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a, x, parity, dagger, comm_override), A(A, false), diff --git a/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh b/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh index b9a1599e94..66b4baa37b 100644 --- a/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh +++ b/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh @@ -22,7 +22,7 @@ namespace quda real b; WilsonCloverHasenbuschTwistPCArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, - const CloverField &A_, double a_, double b_, const ColorSpinorField &x, int parity, + const CloverField &A_, real_t a_, real_t b_, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a_, x, parity, dagger, comm_override), A(A_, false), diff --git a/include/kernels/dslash_wilson_clover_preconditioned.cuh b/include/kernels/dslash_wilson_clover_preconditioned.cuh index 9adbc64343..ca73f5b2b3 100644 --- a/include/kernels/dslash_wilson_clover_preconditioned.cuh +++ b/include/kernels/dslash_wilson_clover_preconditioned.cuh @@ -20,7 +20,7 @@ namespace quda const real a; /** xpay scale factor */ WilsonCloverArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, const CloverField &A, - double a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : + real_t a, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : WilsonArg(out, in, U, a, x, parity, dagger, comm_override), A(A, dynamic_clover ? 
false : true), // if dynamic clover we don't want the inverse field a(a) diff --git a/include/kernels/gauge_det_trace.cuh b/include/kernels/gauge_det_trace.cuh index bc4fc5d5fa..d40f028a95 100644 --- a/include/kernels/gauge_det_trace.cuh +++ b/include/kernels/gauge_det_trace.cuh @@ -9,7 +9,7 @@ namespace quda { enum struct compute_type { determinant, trace }; template - struct KernelArg : public ReduceArg> { + struct KernelArg : public ReduceArg> { static constexpr int nColor = nColor_; static constexpr QudaReconstructType recon = recon_; static constexpr compute_type type = type_; @@ -52,14 +52,15 @@ namespace quda { x[dr] += arg.border[dr]; X[dr] += 2*arg.border[dr]; } + + complex local = {}; #pragma unroll for (int mu = 0; mu < 4; mu++) { Matrix, Arg::nColor> U = arg.u(mu, linkIndex(x, X), parity); - auto local = Arg::type == compute_type::determinant ? getDeterminant(U) : getTrace(U); - value = operator()(value, reduce_t{local.real(), local.imag()}); + local += Arg::type == compute_type::determinant ? getDeterminant(U) : getTrace(U); } - return value; + return operator()(value, array{local.real(), local.imag()}); } }; diff --git a/include/kernels/gauge_fix_fft.cuh b/include/kernels/gauge_fix_fft.cuh index 06b0f89b0f..c0c86088ba 100644 --- a/include/kernels/gauge_fix_fft.cuh +++ b/include/kernels/gauge_fix_fft.cuh @@ -83,7 +83,7 @@ namespace quda { Float alpha; int volume; - GaugeFixArg(GaugeField &data, double alpha) : + GaugeFixArg(GaugeField &data, real_t alpha) : kernel_param(dim3(data.VolumeCB(), 2, 1)), data(data), alpha(static_cast(alpha)), @@ -150,7 +150,7 @@ namespace quda { * @brief container to pass parameters for the gauge fixing quality kernel */ template - struct GaugeFixQualityFFTArg : public ReduceArg> { + struct GaugeFixQualityFFTArg : public ReduceArg> { using real = typename mapper::type; static constexpr QudaReconstructType recon = recon_; using Gauge = typename gauge_mapper::type; @@ -159,7 +159,7 @@ namespace quda { int_fastdiv X[4]; // grid dimensions Gauge data; complex *delta; - reduce_t result; + array result; int volume; GaugeFixQualityFFTArg(const GaugeField &data, complex *delta) : @@ -172,8 +172,8 @@ namespace quda { for (int dir = 0; dir < 4; dir++) X[dir] = data.X()[dir]; } - double getAction() { return result[0]; } - double getTheta() { return result[1]; } + real_t getAction() { return result[0]; } + real_t getTheta() { return result[1]; } }; template struct FixQualityFFT : plus { @@ -189,7 +189,7 @@ namespace quda { */ __device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity) { - reduce_t data{0, 0}; + array data = {}; using matrix = Matrix, 3>; int x[4]; getCoords(x, x_cb, arg.X, parity); @@ -227,7 +227,7 @@ namespace quda { //35 //T=36*gauge_dir+65 - return operator()(data, value); + return operator()(value, data); } }; diff --git a/include/kernels/gauge_fix_ovr.cuh b/include/kernels/gauge_fix_ovr.cuh index 644d9c2463..a678d422ec 100644 --- a/include/kernels/gauge_fix_ovr.cuh +++ b/include/kernels/gauge_fix_ovr.cuh @@ -14,7 +14,7 @@ namespace quda { * @brief container to pass parameters for the gauge fixing quality kernel */ template - struct GaugeFixQualityOVRArg : public ReduceArg> { + struct GaugeFixQualityOVRArg : public ReduceArg> { using real = typename mapper::type; static constexpr QudaReconstructType recon = recon_; using Gauge = typename gauge_mapper::type; @@ -23,7 +23,7 @@ namespace quda { int X[4]; // grid dimensions int border[4]; Gauge data; - reduce_t result; + array result; GaugeFixQualityOVRArg(const 
GaugeField &data) : ReduceArg(dim3(data.LocalVolumeCB(), 2, 1), 1, true), // reset = true @@ -36,8 +36,8 @@ namespace quda { } } - double getAction(){ return result[0]; } - double getTheta(){ return result[1]; } + auto getAction(){ return result[0]; } + auto getTheta(){ return result[1]; } }; template struct FixQualityOVR : plus { @@ -53,7 +53,7 @@ namespace quda { */ __device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity) { - reduce_t data{0, 0}; + array data = {}; using Link = Matrix, 3>; int X[4]; @@ -93,7 +93,7 @@ namespace quda { //35 //T=36*gauge_dir+65 - return operator()(data, value); + return operator()(value, data); } }; @@ -113,7 +113,7 @@ namespace quda { int border[4]; int *borderpoints[2]; - GaugeFixArg(const GaugeField &u, const double relax_boost, int parity, int *borderpoints[2], unsigned threads) : + GaugeFixArg(const GaugeField &u, const real_t relax_boost, int parity, int *borderpoints[2], unsigned threads) : kernel_param(dim3(threads, type < 3 ? 8 : 4, 1)), u(u), relax_boost(static_cast(relax_boost)), diff --git a/include/kernels/gauge_force.cuh b/include/kernels/gauge_force.cuh index 3647efdcf8..d749513af9 100644 --- a/include/kernels/gauge_force.cuh +++ b/include/kernels/gauge_force.cuh @@ -29,7 +29,7 @@ namespace quda { real epsilon; // stepsize and any other overall scaling factor const paths<4> p; - GaugeForceArg(GaugeField &mom, const GaugeField &u, double epsilon, const paths<4> &p) : + GaugeForceArg(GaugeField &mom, const GaugeField &u, real_t epsilon, const paths<4> &p) : kernel_param(dim3(mom.VolumeCB(), 2, 4)), mom(mom), u(u), diff --git a/include/kernels/gauge_loop_trace.cuh b/include/kernels/gauge_loop_trace.cuh index 9c6b6cc3f8..35a17a6f53 100644 --- a/include/kernels/gauge_loop_trace.cuh +++ b/include/kernels/gauge_loop_trace.cuh @@ -18,9 +18,8 @@ namespace quda { constexpr unsigned int max_n_batch_block_loop_trace() { return 8; } template - struct GaugeLoopTraceArg : public ReduceArg> { + struct GaugeLoopTraceArg : public ReduceArg> { using real = typename mapper::type; - using reduce_t = array; static constexpr unsigned int max_n_batch_block = max_n_batch_block_loop_trace(); static constexpr int nColor = nColor_; static constexpr QudaReconstructType recon = recon_; @@ -30,7 +29,7 @@ namespace quda { const Gauge u; - const double factor; // overall scaling factor for all loops + const real factor; // overall scaling factor for all loops static constexpr int nParity = 2; // always true for gauge fields int X[4]; // the regular volume parameters int E[4]; // the extended volume parameters @@ -38,7 +37,7 @@ namespace quda { const paths<1> p; - GaugeLoopTraceArg(const GaugeField &u, double factor, const paths<1> &p) : + GaugeLoopTraceArg(const GaugeField &u, real_t factor, const paths<1> &p) : ReduceArg(dim3(u.LocalVolumeCB(), 2, p.num_paths), p.num_paths), u(u), factor(factor), @@ -65,16 +64,14 @@ namespace quda { { using Link = typename Arg::Link; - reduce_t loop_trace{0, 0}; - int x[4] = {0, 0, 0, 0}; getCoords(x, x_cb, arg.X, parity); for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates thread_array dx{0}; - double coeff_loop = arg.factor * arg.p.path_coeff[path_id]; - if (coeff_loop == 0) return operator()(loop_trace, value); + auto coeff_loop = arg.factor * arg.p.path_coeff[path_id]; + if (coeff_loop == 0) return value; const int* path = arg.p.input_path[0] + path_id * arg.p.max_length; @@ -84,10 +81,8 @@ namespace quda { // compute trace auto trace = getTrace(link_prod); - loop_trace[0] = 
coeff_loop * trace.real(); - loop_trace[1] = coeff_loop * trace.imag(); - - return operator()(loop_trace, value); + array loop_trace = {coeff_loop * trace.real(), coeff_loop * trace.imag()}; + return operator()(value, loop_trace); } }; diff --git a/include/kernels/gauge_phase.cuh b/include/kernels/gauge_phase.cuh index ef57369cb8..03c674e08d 100644 --- a/include/kernels/gauge_phase.cuh +++ b/include/kernels/gauge_phase.cuh @@ -29,8 +29,8 @@ namespace quda { // else we are applying them Float dir = u.StaggeredPhaseApplied() ? -1.0 : 1.0; - i_mu_phase = complex( cos(M_PI * u.iMu() / (u.X()[3]*comm_dim(3)) ), - dir * sin(M_PI * u.iMu() / (u.X()[3]*comm_dim(3))) ); + i_mu_phase = complex( cos(M_PI * double(u.iMu()) / (u.X()[3]*comm_dim(3)) ), + dir * sin(M_PI * double(u.iMu()) / (u.X()[3]*comm_dim(3))) ); for (int d=0; d<4; d++) X[d] = u.X()[d]; diff --git a/include/kernels/gauge_plaq.cuh b/include/kernels/gauge_plaq.cuh index 563d674000..0454817f9f 100644 --- a/include/kernels/gauge_plaq.cuh +++ b/include/kernels/gauge_plaq.cuh @@ -9,7 +9,7 @@ namespace quda { template - struct GaugePlaqArg : public ReduceArg> { + struct GaugePlaqArg : public ReduceArg> { using Float = Float_; static constexpr int nColor = nColor_; static_assert(nColor == 3, "Only nColor=3 enabled at this time"); @@ -62,7 +62,7 @@ namespace quda { // return the plaquette at site (x_cb, parity) __device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity) { - reduce_t plaq{0, 0}; + array plaq {}; int x[4]; getCoords(x, x_cb, arg.X, parity); @@ -79,7 +79,7 @@ namespace quda { plaq[1] += plaquette(arg, x, parity, mu, 3); } - return operator()(plaq, value); + return operator()(value, plaq); } }; diff --git a/include/kernels/gauge_polyakov_loop.cuh b/include/kernels/gauge_polyakov_loop.cuh index 5ea4bc5219..06e6505526 100644 --- a/include/kernels/gauge_polyakov_loop.cuh +++ b/include/kernels/gauge_polyakov_loop.cuh @@ -166,7 +166,7 @@ namespace quda { }; template - struct GaugePolyakovLoopTraceArg : public ReduceArg> { + struct GaugePolyakovLoopTraceArg : public ReduceArg> { using real = typename mapper::type; static constexpr int nColor = nColor_; static_assert(nColor == 3, "Only nColor=3 enabled at this time"); @@ -212,7 +212,7 @@ namespace quda { using HighPrecLink = typename Arg::HighPrecLink; HighPrecLink polyloop; - reduce_t ploop{0, 0}; + array ploop = {}; int x[4]; getCoords(x, x_cb, arg.X, parity); @@ -235,7 +235,7 @@ namespace quda { ploop[0] = tr.real(); ploop[1] = tr.imag(); - return operator()(ploop, value); + return operator()(value, ploop); } }; diff --git a/include/kernels/gauge_qcharge.cuh b/include/kernels/gauge_qcharge.cuh index fadfd45321..638764c392 100644 --- a/include/kernels/gauge_qcharge.cuh +++ b/include/kernels/gauge_qcharge.cuh @@ -7,7 +7,7 @@ namespace quda { template struct QChargeArg : - public ReduceArg> + public ReduceArg> { using Float = Float_; static constexpr int nColor = nColor_; @@ -42,8 +42,8 @@ namespace quda constexpr real q_norm = static_cast(-1.0 / (4*M_PI*M_PI)); constexpr real n_inv = static_cast(1.0 / Arg::nColor); - reduce_t E_local{0, 0, 0}; - double &Q = E_local[2]; + array E_local{0, 0, 0}; + auto &Q = E_local[2]; // Load the field-strength tensor from global memory //F0 = F[Y,X], F1 = F[Z,X], F2 = F[Z,Y], diff --git a/include/kernels/gauge_random.cuh b/include/kernels/gauge_random.cuh index 3c348d73b3..0e7b191d9d 100644 --- a/include/kernels/gauge_random.cuh +++ b/include/kernels/gauge_random.cuh @@ -23,7 +23,7 @@ namespace quda { RNGState *rng; real 
sigma; // where U = exp(sigma * H) - GaugeGaussArg(const GaugeField &U, RNGState *rng, double sigma) : + GaugeGaussArg(const GaugeField &U, RNGState *rng, real_t sigma) : kernel_param(dim3(U.LocalVolumeCB(), 2, 1)), U(U), rng(rng), diff --git a/include/kernels/laplace.cuh b/include/kernels/laplace.cuh index a029242210..21693985f4 100644 --- a/include/kernels/laplace.cuh +++ b/include/kernels/laplace.cuh @@ -38,7 +38,7 @@ namespace quda const real b; /** used by Wuppertal smearing kernel */ int dir; /** The direction from which to omit the derivative */ - LaplaceArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, double a, double b, + LaplaceArg(ColorSpinorField &out, const ColorSpinorField &in, const GaugeField &U, int dir, real_t a, real_t b, const ColorSpinorField &x, int parity, bool dagger, const int *comm_override) : DslashArg(out, in, U, x, parity, dagger, a != 0.0 ? true : false, 1, false, comm_override), out(out), diff --git a/include/kernels/momentum.cuh b/include/kernels/momentum.cuh index 41f8bc929c..daa8d669de 100644 --- a/include/kernels/momentum.cuh +++ b/include/kernels/momentum.cuh @@ -7,14 +7,14 @@ namespace quda { template - struct MomActionArg : ReduceArg { + struct MomActionArg : ReduceArg { using Float = Float_; static constexpr int nColor = nColor_; static constexpr QudaReconstructType recon = recon_; const typename gauge_mapper::type mom; MomActionArg(const GaugeField &mom) : - ReduceArg(dim3(mom.VolumeCB(), 2, 1)), + ReduceArg(dim3(mom.VolumeCB(), 2, 1)), mom(mom) { } }; @@ -56,7 +56,7 @@ namespace quda { }; template - struct UpdateMomArg : ReduceArg> + struct UpdateMomArg : ReduceArg> { using Float = Float_; static constexpr int nColor = nColor_; @@ -109,7 +109,7 @@ namespace quda { makeAntiHerm(f); // compute force norms - norm = operator()(reduce_t{f.L1(), f.L2()}, norm); + norm = operator()(norm, {f.L1(), f.L2()}); m = m + arg.coeff * f; diff --git a/include/kernels/multi_blas_core.cuh b/include/kernels/multi_blas_core.cuh index 449e1de12b..94558fb7f6 100644 --- a/include/kernels/multi_blas_core.cuh +++ b/include/kernels/multi_blas_core.cuh @@ -89,8 +89,8 @@ namespace quda if (arg.f.read.Y) arg.Y[k].load(y, idx, parity); if (arg.f.read.W) arg.W[k].load(w, idx, parity); } else { - y = ::quda::zero, Arg::n/2>(); - w = ::quda::zero, Arg::n/2>(); + y = {}; + w = {}; } #pragma unroll diff --git a/include/kernels/multi_reduce_core.cuh b/include/kernels/multi_reduce_core.cuh index 9fcadd14c9..63dffc1f31 100644 --- a/include/kernels/multi_reduce_core.cuh +++ b/include/kernels/multi_reduce_core.cuh @@ -10,6 +10,8 @@ namespace quda { + using compute_t = double; + namespace blas { @@ -142,10 +144,10 @@ namespace quda Return the real dot product of x and y */ template - __device__ __host__ void dot_(reduce_t &sum, const complex &a, const complex &b) + __device__ __host__ auto dot_(const complex &a, const complex &b) { - sum += static_cast(a.real()) * static_cast(b.real()); - sum += static_cast(a.imag()) * static_cast(b.imag()); + auto d = reduce_t(a.real()) * reduce_t(b.real()); + return fma(reduce_t(a.imag()), reduce_t(b.imag()), d); } template @@ -159,7 +161,7 @@ namespace quda template __device__ __host__ inline void operator()(reduce_t &sum, T &x, T &y, T &, T &, int, int) const { #pragma unroll - for (int k=0; k < x.size(); k++) dot_(sum, x[k], y[k]); + for (int k=0; k < x.size(); k++) sum = plus::apply(sum, dot_(x[k], y[k])); } constexpr int flops() const { return 2; } //!
flops per element @@ -169,13 +171,14 @@ namespace quda Returns complex-valued dot product of x and y */ template - __device__ __host__ void cdot_(reduce_t &sum, const complex &a, const complex &b) + __device__ __host__ auto cdot_(const complex &a, const complex &b) { - using scalar = typename reduce_t::value_type; - sum[0] += static_cast(a.real()) * static_cast(b.real()); - sum[0] += static_cast(a.imag()) * static_cast(b.imag()); - sum[1] += static_cast(a.real()) * static_cast(b.imag()); - sum[1] -= static_cast(a.imag()) * static_cast(b.real()); + using scalar_t = typename reduce_t::value_type; + auto r = scalar_t(a.real()) * scalar_t(b.real()); + r = fma(scalar_t(a.imag()), scalar_t(b.imag()), r); + auto i = scalar_t(a.real()) * scalar_t(b.imag()); + i = fma(-scalar_t(a.imag()) , scalar_t(b.real()), i); + return reduce_t{r, i}; } template @@ -190,7 +193,7 @@ namespace quda template __device__ __host__ inline void operator()(reduce_t &sum, T &x, T &y, T &, T &, int, int) const { #pragma unroll - for (int k=0; k < x.size(); k++) cdot_(sum, x[k], y[k]); + for (int k=0; k < x.size(); k++) sum = plus::apply(sum, cdot_>(x[k], y[k])); } constexpr int flops() const { return 4; } //! flops per element @@ -209,7 +212,7 @@ namespace quda { #pragma unroll for (int k = 0; k < x.size(); k++) { - cdot_(sum, x[k], y[k]); + sum = plus::apply(sum, cdot_(x[k], y[k])); if (i == j) w[k] = y[k]; } } diff --git a/include/kernels/reduce_core.cuh b/include/kernels/reduce_core.cuh index 4597f0b8cd..fb73b79bb3 100644 --- a/include/kernels/reduce_core.cuh +++ b/include/kernels/reduce_core.cuh @@ -2,6 +2,7 @@ #include #include +#include "float_vector.h" #include #include #include @@ -9,6 +10,8 @@ namespace quda { + using compute_t = double; + namespace blas { @@ -104,7 +107,6 @@ namespace quda struct ReduceFunctor { static constexpr use_kernel_arg_p use_kernel_arg = use_kernel_arg_p::TRUE; using reduce_t = reduce_t_; - using reducer = plus; static constexpr bool site_unroll = site_unroll_; //! pre-computation routine called before the "M-loop" @@ -114,12 +116,13 @@ namespace quda __device__ __host__ void post(reduce_t &) const { ; } }; - template - struct Max : public ReduceFunctor { + template + struct Max : public ReduceFunctor { + using reduce_t = compute_t; using reducer = maximum; static constexpr memory_access<1> read{ }; static constexpr memory_access<> write{ }; - Max(const real &, const real &) { ; } + Max(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &max, T &x, T &, T &, T &, T &) const { #pragma unroll @@ -131,18 +134,18 @@ namespace quda constexpr int flops() const { return 0; } //! 
flops per element }; - template - struct MaxDeviation : public ReduceFunctor> { - using reduce_t = deviation_t; + template + struct MaxDeviation : public ReduceFunctor> { + using reduce_t = deviation_t; using reducer = maximum; static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<> write{ }; - MaxDeviation(const real &, const real &) { ; } + MaxDeviation(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &max, T &x, T &y, T &, T &, T &) const { #pragma unroll for (int i = 0; i < x.size(); i++) { - complex diff = {abs(x[i].real() - y[i].real()), abs(x[i].imag() - y[i].imag())}; + complex diff = {abs(x[i].real() - y[i].real()), abs(x[i].imag() - y[i].imag())}; if (diff.real() > max.diff ) { max.diff = diff.real(); max.ref = abs(y[i].real()); @@ -159,20 +162,21 @@ namespace quda /** Return the L1 norm of x */ - template __device__ __host__ reduce_t norm1_(const complex &a) + template __device__ __host__ auto norm1_(const complex &a) { - return static_cast(sqrt(a.real() * a.real() + a.imag() * a.imag())); + return reduce_t(sqrt(a.real() * a.real() + a.imag() * a.imag())); } template struct Norm1 : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1> read{ }; static constexpr memory_access<> write{ }; - Norm1(const real &, const real &) { ; } + Norm1(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &, T &, T &, T &) const { #pragma unroll - for (int i=0; i < x.size(); i++) sum += norm1_(x[i]); + for (int i=0; i < x.size(); i++) sum = reducer::apply(sum, norm1_(x[i])); } constexpr int flops() const { return 2; } //! flops per element }; @@ -180,21 +184,22 @@ namespace quda /** Return the L2 norm of x */ - template __device__ __host__ void norm2_(reduce_t &sum, const complex &a) + template __device__ __host__ auto norm2_(const complex &a) { - sum += static_cast(a.real()) * static_cast(a.real()); - sum += static_cast(a.imag()) * static_cast(a.imag()); + auto n = reduce_t(a.real()) * reduce_t(a.real()); + return fma(reduce_t(a.imag()), reduce_t(a.imag()), n); } template struct Norm2 : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1> read{ }; static constexpr memory_access<> write{ }; - Norm2(const real &, const real &) { ; } + Norm2(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &, T &, T &, T &) const { #pragma unroll - for (int i = 0; i < x.size(); i++) norm2_(sum, x[i]); + for (int i = 0; i < x.size(); i++) sum = reducer::apply(sum, norm2_(x[i])); } constexpr int flops() const { return 2; } //! 
flops per element }; @@ -203,21 +208,22 @@ namespace quda Return the real dot product of x and y */ template - __device__ __host__ void dot_(reduce_t &sum, const complex &a, const complex &b) + __device__ __host__ auto dot_(const complex &a, const complex &b) { - sum += static_cast(a.real()) * static_cast(b.real()); - sum += static_cast(a.imag()) * static_cast(b.imag()); + auto d = reduce_t(a.real()) * reduce_t(b.real()); + return fma(reduce_t(a.imag()), reduce_t(b.imag()), d); } template struct Dot : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1,1> read{ }; static constexpr memory_access<> write{ }; - Dot(const real &, const real &) { ; } + Dot(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &, T &, T &) const { #pragma unroll - for (int i = 0; i < x.size(); i++) dot_(sum, x[i], y[i]); + for (int i = 0; i < x.size(); i++) sum = reducer::apply(sum, dot_(x[i], y[i])); } constexpr int flops() const { return 2; } //! flops per element }; @@ -228,17 +234,18 @@ namespace quda */ template struct axpbyzNorm2 : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1, 1, 0> read{ }; static constexpr memory_access<0, 0, 1> write{ }; const real a; const real b; - axpbyzNorm2(const real &a, const real &b) : a(a), b(b) { ; } + axpbyzNorm2(const real_t &a, const real_t &b) : a(a), b(b) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &z, T &, T &) const { #pragma unroll for (int i = 0; i < x.size(); i++) { z[i] = a * x[i] + b * y[i]; - norm2_(sum, z[i]); + sum = reducer::apply(sum, norm2_(z[i])); } } constexpr int flops() const { return 4; } //! flops per element @@ -250,16 +257,17 @@ namespace quda */ template struct AxpyReDot : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<0, 1> write{ }; const real a; - AxpyReDot(const real &a, const real &) : a(a) { ; } + AxpyReDot(const real_t &a, const real_t &) : a(a) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &, T &, T &) const { #pragma unroll for (int i = 0; i < x.size(); i++) { y[i] += a * x[i]; - dot_(sum, x[i], y[i]); + sum = reducer::apply(sum, dot_(x[i], y[i])); } } constexpr int flops() const { return 4; } //! flops per element @@ -271,6 +279,7 @@ namespace quda */ template struct caxpyNorm2 : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<0, 1> write{ }; const complex a; @@ -280,20 +289,21 @@ namespace quda #pragma unroll for (int i = 0; i < x.size(); i++) { y[i] = cmac(a, x[i], y[i]); - norm2_(sum, y[i]); + sum = reducer::apply(sum, norm2_(y[i])); } } constexpr int flops() const { return 6; } //! flops per element }; /** - double cabxpyzAxNorm(float a, complex b, float *x, float *y, float *z){} + cabxpyzAxNorm(float a, complex b, float *x, float *y, float *z){} First performs the operation z[i] = y[i] + a*b*x[i] Second performs x[i] *= a Third returns the norm of x */ template struct cabxpyzaxnorm : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1, 1, 0> read{ }; static constexpr memory_access<1, 0, 1> write{ }; const real a; @@ -305,7 +315,7 @@ namespace quda for (int i = 0; i < x.size(); i++) { x[i] *= a; z[i] = cmac(b, x[i], y[i]); - norm2_(sum, z[i]); + sum = reducer::apply(sum, norm2_(z[i])); } } constexpr int flops() const { return 10; } //! 
flops per element @@ -315,37 +325,42 @@ namespace quda Returns complex-valued dot product of x and y */ template - __device__ __host__ void cdot_(reduce_t &sum, const complex &a, const complex &b) + __device__ __host__ auto cdot_(const complex &a, const complex &b) { using scalar_t = typename reduce_t::value_type; - sum[0] += static_cast(a.real()) * static_cast(b.real()); - sum[0] += static_cast(a.imag()) * static_cast(b.imag()); - sum[1] += static_cast(a.real()) * static_cast(b.imag()); - sum[1] -= static_cast(a.imag()) * static_cast(b.real()); + auto r = scalar_t(a.real()) * scalar_t(b.real()); + r = fma(scalar_t(a.imag()), scalar_t(b.imag()), r); + auto i = scalar_t(a.real()) * scalar_t(b.imag()); + i = fma(-scalar_t(a.imag()) , scalar_t(b.real()), i); + return reduce_t{r, i}; } template struct Cdot : public ReduceFunctor> { using reduce_t = array; + using compute_t = array; + using reducer = plus; static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<> write{ }; - Cdot(const complex &, const complex &) { ; } + Cdot(const complex_t &, const complex_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &, T &, T &) const { #pragma unroll - for (int i = 0; i < x.size(); i++) cdot_(sum, x[i], y[i]); + for (int i = 0; i < x.size(); i++) sum = reducer::apply(sum, cdot_(x[i], y[i])); } constexpr int flops() const { return 4; } //! flops per element }; /** - double caxpyDotzyCuda(float a, float *x, float *y, float *z, n){} + caxpyDotzyCuda(float a, float *x, float *y, float *z, n){} First performs the operation y[i] = a*x[i] + y[i] Second returns the dot product (z,y) */ template struct caxpydotzy : public ReduceFunctor> { using reduce_t = array; + using compute_t = array; + using reducer = plus; static constexpr memory_access<1, 1, 1> read{ }; static constexpr memory_access<0, 1> write{ }; const complex a; @@ -355,7 +370,7 @@ namespace quda #pragma unroll for (int i = 0; i < x.size(); i++) { y[i] = cmac(a, x[i], y[i]); - cdot_(sum, z[i], y[i]); + sum = reducer::apply(sum, cdot_(z[i], y[i])); } } constexpr int flops() const { return 8; } //! flops per element @@ -366,25 +381,27 @@ namespace quda Returns the norm of x */ template - __device__ __host__ void cdotNormAB_(reduce_t &sum, const InputType &a, const InputType &b) + __device__ __host__ reduce_t cdotNormAB_(const InputType &a, const InputType &b) { using real = typename InputType::value_type; using scalar = typename reduce_t::value_type; - cdot_(sum, a, b); - norm2_(sum[2], a); - norm2_(sum[3], b); + auto cdot = cdot_(a, b); + return {cdot[0], cdot[1], norm2_(a), norm2_(b)}; } template struct CdotNormAB : public ReduceFunctor> { using reduce_t = array; + using reducer = plus; static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<> write{ }; - CdotNormAB(const real &, const real &) { ; } + CdotNormAB(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &, T &, T &) const { #pragma unroll - for (int i = 0; i < x.size(); i++) cdotNormAB_(sum, x[i], y[i]); + for (int i = 0; i < x.size(); i++) { + sum = reducer::apply(sum, cdotNormAB_>(x[i], y[i])); + } } constexpr int flops() const { return 8; } //! 
flops per element }; @@ -394,12 +411,12 @@ namespace quda Returns the norm of y */ template - __device__ __host__ void cdotNormB_(reduce_t &sum, const InputType &a, const InputType &b) + __device__ __host__ reduce_t cdotNormB_(const InputType &a, const InputType &b) { using real = typename InputType::value_type; using scalar = typename reduce_t::value_type; - cdot_(sum, a, b); - norm2_(sum[2], b); + auto cdot = cdot_(a, b); + return {cdot[0], cdot[1], norm2_(b)}; } /** @@ -409,6 +426,7 @@ namespace quda template struct caxpbypzYmbwcDotProductUYNormY_ : public ReduceFunctor> { using reduce_t = array; + using reducer = plus; static constexpr memory_access<1, 1, 1, 1, 1> read{ }; static constexpr memory_access<0, 1, 1> write{ }; const complex a; @@ -421,7 +439,7 @@ namespace quda y[i] = cmac(a, x[i], y[i]); y[i] = cmac(b, z[i], y[i]); z[i] = cmac(-b, w[i], z[i]); - cdotNormB_(sum, v[i], z[i]); + sum = reducer::apply(sum, cdotNormB_>(v[i], z[i])); } } constexpr int flops() const { return 18; } //! flops per element @@ -436,17 +454,17 @@ namespace quda template struct axpyCGNorm2 : public ReduceFunctor> { using reduce_t = array; + using reducer = plus; static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<0, 1> write{ }; const real a; - axpyCGNorm2(const real &a, const real &) : a(a) { ; } + axpyCGNorm2(const real_t &a, const real_t &) : a(a) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &, T &, T &) const { #pragma unroll for (int i = 0; i < x.size(); i++) { auto y_new = y[i] + a * x[i]; - norm2_(sum[0], y_new); - dot_(sum[1], y_new, y_new - y[i]); + sum = reducer::apply(sum, array{norm2_(y_new), dot_(y_new, y_new - y[i])}); y[i] = y_new; } } @@ -470,30 +488,21 @@ namespace quda static constexpr memory_access<1, 1> read{ }; static constexpr memory_access<> write{ }; - reduce_t aux; - HeavyQuarkResidualNorm_(const real &, const real &) : aux {} { ; } + array aux; + HeavyQuarkResidualNorm_(const real_t &, const real_t &) : aux {} { ; } - __device__ __host__ void pre() - { - aux[0] = 0; - aux[1] = 0; - } + __device__ __host__ void pre() { aux = {}; } template __device__ __host__ void operator()(reduce_t &, T &x, T &y, T &, T &, T &) { #pragma unroll - for (int i = 0; i < x.size(); i++) { - norm2_(aux[0], x[i]); - norm2_(aux[1], y[i]); - } + for (int i = 0; i < x.size(); i++) aux = plus>::apply(aux, {norm2_(x[i]), norm2_(y[i])}); } //! sum the solution and residual norms, and compute the heavy-quark norm __device__ __host__ void post(reduce_t &sum) { - sum[0] += aux[0]; - sum[1] += aux[1]; - sum[2] += (aux[0] > 0.0) ? (aux[1] / aux[0]) : static_cast(1.0); + sum = reducer::apply(sum, array{aux[0], aux[1], (aux[0] > 0.0) ? (aux[1] / aux[0]) : compute_t(1.0)}); } constexpr int flops() const { return 4; } //! 
undercounts since it excludes the per-site division @@ -517,87 +526,76 @@ namespace quda static constexpr memory_access<1, 1, 1> read{ }; static constexpr memory_access<> write{ }; - reduce_t aux; - xpyHeavyQuarkResidualNorm_(const real &, const real &) : aux {} { ; } + array aux; + xpyHeavyQuarkResidualNorm_(const real_t &, const real_t &) : aux {} { ; } - __device__ __host__ void pre() - { - aux[0] = 0; - aux[1] = 0; - } + __device__ __host__ void pre() { aux = {}; } template __device__ __host__ void operator()(reduce_t &, T &x, T &y, T &z, T &, T &) { #pragma unroll - for (int i = 0; i < x.size(); i++) { - norm2_(aux[0], x[i] + y[i]); - norm2_(aux[1], z[i]); - } + for (int i = 0; i < x.size(); i++) aux = plus>::apply(aux, {norm2_(x[i] + y[i]), norm2_(z[i])}); } //! sum the solution and residual norms, and compute the heavy-quark norm __device__ __host__ void post(reduce_t &sum) { - sum[0] += aux[0]; - sum[1] += aux[1]; - sum[2] += (aux[0] > 0.0) ? (aux[1] / aux[0]) : static_cast(1.0); + sum = reducer::apply(sum, array{aux[0], aux[1], (aux[0] > 0.0) ? (aux[1] / aux[0]) : compute_t(1.0)}); } constexpr int flops() const { return 5; } }; /** - double3 tripleCGReduction(V x, V y, V z){} + tripleCGReduction(V x, V y, V z){} First performs the operation norm2(x) - Second performs the operatio norm2(y) + Second performs the operation norm2(y) Third performs the operation dotProduct(y,z) */ template struct tripleCGReduction_ : public ReduceFunctor> { using reduce_t = array; + using reducer = plus; static constexpr memory_access<1, 1, 1> read{ }; static constexpr memory_access<> write{ }; - tripleCGReduction_(const real &, const real &) { ; } + tripleCGReduction_(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &z, T &, T &) const { #pragma unroll for (int i = 0; i < x.size(); i++) { - norm2_(sum[0], x[i]); - norm2_(sum[1], y[i]); - dot_(sum[2], y[i], z[i]); + sum = reducer::apply(sum, array{norm2_(x[i]), norm2_(y[i]), dot_(y[i], z[i])}); } } constexpr int flops() const { return 6; } //! flops per element }; /** - double4 quadrupleCGReduction(V x, V y, V z){} + quadrupleCGReduction(V x, V y, V z){} First performs the operation norm2(x) - Second performs the operatio norm2(y) + Second performs the operation norm2(y) Third performs the operation dotProduct(y,z) Fourth performs the operation norm(z) */ template struct quadrupleCGReduction_ : public ReduceFunctor> { using reduce_t = array; + using reducer = plus; static constexpr memory_access<1, 1, 1, 1> read{ }; static constexpr memory_access<> write{ }; - quadrupleCGReduction_(const real &, const real &) { ; } + quadrupleCGReduction_(const real_t &, const real_t &) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &z, T &w, T &) const { #pragma unroll for (int i = 0; i < x.size(); i++) { - norm2_(sum[0], x[i]); - norm2_(sum[1], y[i]); - dot_(sum[2], y[i], z[i]); - norm2_(sum[3], w[i]); + sum = reducer::apply(sum, array{norm2_(x[i]), norm2_(y[i]), + dot_(y[i], z[i]), norm2_(w[i])}); } } constexpr int flops() const { return 8; } //!
flops per element }; /** - double quadrupleCG3InitNorm(d a, d b, V x, V y, V z, V w, V v){} + quadrupleCG3InitNorm(d a, d b, V x, V y, V z, V w, V v){} z = x; w = y; x += a*y; @@ -606,10 +604,11 @@ namespace quda */ template struct quadrupleCG3InitNorm_ : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1, 1, 0, 0, 1> read{ }; static constexpr memory_access<1, 1, 1, 1> write{ }; const real a; - quadrupleCG3InitNorm_(const real &a, const real &) : a(a) { ; } + quadrupleCG3InitNorm_(const real_t &a, const real_t &) : a(a) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &z, T &w, T &v) const { #pragma unroll @@ -618,14 +617,14 @@ namespace quda w[i] = y[i]; x[i] += a * y[i]; y[i] -= a * v[i]; - norm2_(sum, y[i]); + sum = reducer::apply(sum, norm2_(y[i])); } } constexpr int flops() const { return 6; } //! flops per element check if it's right }; /** - double quadrupleCG3UpdateNorm(d gamma, d rho, V x, V y, V z, V w, V v){} + quadrupleCG3UpdateNorm(d gamma, d rho, V x, V y, V z, V w, V v){} tmpx = x; tmpy = y; x = b*(x + a*y) + (1-b)*z; @@ -636,11 +635,12 @@ namespace quda */ template struct quadrupleCG3UpdateNorm_ : public ReduceFunctor { + using reducer = plus; static constexpr memory_access<1, 1, 1, 1, 1> read{ }; static constexpr memory_access<1, 1, 1, 1> write{ }; const real a; const real b; - quadrupleCG3UpdateNorm_(const real &a, const real &b) : a(a), b(b) { ; } + quadrupleCG3UpdateNorm_(const real_t &a, const real_t &b) : a(a), b(b) { ; } template __device__ __host__ void operator()(reduce_t &sum, T &x, T &y, T &z, T &w, T &v) const { #pragma unroll @@ -651,7 +651,7 @@ namespace quda y[i] = b * (y[i] - a * v[i]) + ((real)1.0 - b) * w[i]; z[i] = tmpx; w[i] = tmpy; - norm2_(sum, y[i]); + sum = reducer::apply(sum, norm2_(y[i])); } } constexpr int flops() const { return 16; } //! 
flops per element check if it's right
diff --git a/include/kernels/staggered_coarse_op_kernel.cuh b/include/kernels/staggered_coarse_op_kernel.cuh
index 7fe913bfdf..1786a23f09 100644
--- a/include/kernels/staggered_coarse_op_kernel.cuh
+++ b/include/kernels/staggered_coarse_op_kernel.cuh
@@ -34,7 +34,7 @@ namespace quda {
const int coarseVolumeCB; /** Coarse grid volume */
real two_mass; /** Two times the staggered mass value */
- CalculateStaggeredYArg(GaugeField &Y, GaugeField &X, const GaugeField &U, double mass) :
+ CalculateStaggeredYArg(GaugeField &Y, GaugeField &X, const GaugeField &U, real_t mass) :
kernel_param(dim3(U.VolumeCB(), fineColor * fineColor, 2)),
Y(Y),
X(X),
diff --git a/include/kernels/staggered_kd_reorder_xinv_kernel.cuh b/include/kernels/staggered_kd_reorder_xinv_kernel.cuh
index e8c7dd9d65..727306ef0d 100644
--- a/include/kernels/staggered_kd_reorder_xinv_kernel.cuh
+++ b/include/kernels/staggered_kd_reorder_xinv_kernel.cuh
@@ -41,7 +41,7 @@ namespace quda {
static constexpr int coarse_color = coarseColor;
- CalculateStaggeredGeometryReorderArg(GaugeField& fineXinv, const GaugeField& coarseXinv, const double scale) :
+ CalculateStaggeredGeometryReorderArg(GaugeField& fineXinv, const GaugeField& coarseXinv, const real_t scale) :
kernel_param(dim3(fineXinv.VolumeCB(), kdBlockSize, 2)),
fineXinv(fineXinv),
coarseXinv(coarseXinv),
diff --git a/include/kernels/transform_reduce.cuh b/include/kernels/transform_reduce.cuh
index af9df8d54a..f6fa6967b5 100644
--- a/include/kernels/transform_reduce.cuh
+++ b/include/kernels/transform_reduce.cuh
@@ -47,7 +47,7 @@ namespace quda {
auto k = arg.m(i);
auto v = arg.v[j];
auto t = arg.h(v[k]);
- return operator()(t, value);
+ return operator()(value, t);
}
};
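[Editor's example: the operand swap above matters once the running value and the transformed element may have different types, as with compensated or reproducible accumulators, where addition is only defined with the accumulator on the left. A self-contained host-only sketch with a toy compensated (Kahan) accumulator; the names are illustrative, not QUDA API.]

#include <cstdio>

// Toy compensated accumulator: op(acc, x) is only defined with the
// accumulator as the left operand, mirroring operator()(value, t) above.
struct kahan_plus {
  struct acc_t { double sum = 0, c = 0; }; // running sum and compensation
  acc_t operator()(acc_t a, double x) const
  {
    double y = x - a.c;
    double t = a.sum + y;
    a.c = (t - a.sum) - y; // capture the rounding error of the addition
    a.sum = t;
    return a;
  }
};

int main()
{
  kahan_plus op;
  kahan_plus::acc_t value;
  for (int i = 0; i < 10000000; i++) value = op(value, 1e-10); // accumulator stays on the left
  printf("%.12g\n", value.sum); // ~1e-3, with the compensation intact
}

diff --git a/include/lattice_field.h b/include/lattice_field.h
index e7c43b7d69..e745393366 100644
--- a/include/lattice_field.h
+++ b/include/lattice_field.h
@@ -81,7 +81,7 @@ namespace quda {
lat_dim_t r = {};
/** For fixed-point fields that need a global scaling factor */
- double scale = 1.0;
+ real_t scale = 1.0;
/**
@brief Default constructor for LatticeFieldParam
@@ -217,7 +217,7 @@ namespace quda {
mutable bool ghost_precision_reset = false;
/** For fixed-point fields that need a global scaling factor */
- double scale = 0.0;
+ real_t scale = 0.0;
/** Whether the field is full or single parity */
QudaSiteSubset siteSubset = QUDA_INVALID_SITE_SUBSET;
@@ -666,13 +666,13 @@ namespace quda {
/**
@return The global scaling factor for a fixed-point field
*/
- double Scale() const { return scale; }
+ real_t Scale() const { return scale; }
/**
@brief Set the scale factor for a fixed-point field
@param[in] scale_ The new scale factor
*/
- void Scale(double scale_) { scale = scale_; }
+ void Scale(real_t scale_) { scale = scale_; }
/**
@return Field subset type
diff --git a/include/madwf_ml.h b/include/madwf_ml.h
index c017229311..8b812564ce 100644
--- a/include/madwf_ml.h
+++ b/include/madwf_ml.h
@@ -82,7 +82,7 @@ namespace quda
MadwfParam param;
- double mu;
+ real_t mu;
ColorSpinorField forward_tmp;
ColorSpinorField backward_tmp;
@@ -106,7 +106,7 @@ namespace quda
@param[out] out the output vector
@param[in] in the input vector
*/
- double cost(const DiracMatrix &ref, Solver &base, ColorSpinorField &out, const ColorSpinorField &in);
+ real_t cost(const DiracMatrix &ref, Solver &base, ColorSpinorField &out, const ColorSpinorField &in);
/**
@brief Save the current parameter to disk
diff --git a/include/madwf_param.h b/include/madwf_param.h
index 5046ee6a67..4339e70dd8 100644
---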
a/include/madwf_param.h +++ b/include/madwf_param.h @@ -8,7 +8,7 @@ namespace quda struct MadwfParam { /** The diagonal constant to suppress the low modes when performing 5D transfer */ - double madwf_diagonal_suppressor; + real_t madwf_diagonal_suppressor; /** The target MADWF Ls to be used in the accelerator */ int madwf_ls; @@ -17,7 +17,7 @@ namespace quda int madwf_null_miniter; /** The maximum tolerance after which to generate the null vectors for MADWF */ - double madwf_null_tol; + real_t madwf_null_tol; /** The maximum number of iterations for the training iterations */ int madwf_train_maxiter; diff --git a/include/multigrid.h b/include/multigrid.h index 82a46998c4..84dd2b2e99 100644 --- a/include/multigrid.h +++ b/include/multigrid.h @@ -121,7 +121,7 @@ namespace quda { int nu_post; /** Tolerance to use for the solver / smoother (if applicable) */ - double smoother_tol; + real_t smoother_tol; /** Multigrid cycle type */ QudaMultigridCycleType cycle_type; @@ -399,7 +399,7 @@ namespace quda { */ void resetStaggeredKD(GaugeField *gauge_in, GaugeField *fat_gauge_in, GaugeField *long_gauge_in, GaugeField *gauge_sloppy_in, GaugeField *fat_gauge_sloppy_in, - GaugeField *long_gauge_sloppy_in, double mass); + GaugeField *long_gauge_sloppy_in, real_t mass); /** @brief Dump the null-space vectors to disk. Will recurse dumping all levels. @@ -526,7 +526,7 @@ namespace quda { @param[in] halo_precision What precision to use for the halos (if QUDA_INVALID_PRECISION, use field precision) */ void ApplyCoarse(cvector_ref &out, cvector_ref &inA, - cvector_ref &inB, const GaugeField &Y, const GaugeField &X, double kappa, + cvector_ref &inB, const GaugeField &Y, const GaugeField &X, real_t kappa, int parity = QUDA_INVALID_PARITY, bool dslash = true, bool clover = true, bool dagger = false, const int *commDim = 0, QudaPrecision halo_precision = QUDA_INVALID_PRECISION, bool use_mma = false); @@ -551,7 +551,7 @@ namespace quda { */ template void ApplyCoarse(cvector_ref &out, cvector_ref &inA, - cvector_ref &inB, const GaugeField &Y, const GaugeField &X, double kappa, + cvector_ref &inB, const GaugeField &Y, const GaugeField &X, real_t kappa, int parity, bool dslash, bool clover, const int *commDim, QudaPrecision halo_precision); /** @@ -595,7 +595,7 @@ namespace quda { even-odd preconditioned and we coarsen the full operator. */ void CoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const GaugeField &gauge, const CloverField *clover, - double kappa, double mass, double mu, double mu_factor, QudaDiracType dirac, QudaMatPCType matpc); + real_t kappa, real_t mass, real_t mu, real_t mu_factor, QudaDiracType dirac, QudaMatPCType matpc); /** @brief Coarse operator construction from a fine-grid operator @@ -617,7 +617,7 @@ namespace quda { */ template void CoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const GaugeField &gauge, const CloverField *clover, - double kappa, double mass, double mu, double mu_factor, QudaDiracType dirac, QudaMatPCType matpc); + real_t kappa, real_t mass, real_t mu, real_t mu_factor, QudaDiracType dirac, QudaMatPCType matpc); /** @brief Coarse operator construction from a fine-grid operator (Staggered) @@ -635,12 +635,12 @@ namespace quda { For staggered, should always be QUDA_MATPC_INVALID. 
*/ void StaggeredCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const GaugeField &gauge, - const GaugeField &longGauge, const GaugeField &XinvKD, double mass, bool allow_truncation, + const GaugeField &longGauge, const GaugeField &XinvKD, real_t mass, bool allow_truncation, QudaDiracType dirac, QudaMatPCType matpc); template void StaggeredCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const GaugeField &gauge, - const GaugeField &longGauge, const GaugeField &XinvKD, double mass, bool allow_truncation, + const GaugeField &longGauge, const GaugeField &XinvKD, real_t mass, bool allow_truncation, QudaDiracType dirac, QudaMatPCType matpc); /** @@ -665,7 +665,7 @@ namespace quda { @param use_mma[in] Whether or not use MMA (tensor core) to do the calculation, default to false */ void CoarseCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const GaugeField &gauge, const GaugeField &clover, - const GaugeField &cloverInv, double kappa, double mass, double mu, double mu_factor, + const GaugeField &cloverInv, real_t kappa, real_t mass, real_t mu, real_t mu_factor, QudaDiracType dirac, QudaMatPCType matpc, bool need_bidirectional, bool use_mma = false); /** @@ -695,8 +695,8 @@ namespace quda { */ template void CoarseCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, const GaugeField &gauge, - const GaugeField &clover, const GaugeField &cloverInv, double kappa, double mass, double mu, - double mu_factor, QudaDiracType dirac, QudaMatPCType matpc, bool need_bidirectional); + const GaugeField &clover, const GaugeField &cloverInv, real_t kappa, real_t mass, real_t mu, + real_t mu_factor, QudaDiracType dirac, QudaMatPCType matpc, bool need_bidirectional); /** @brief Calculate preconditioned coarse links and coarse clover inverse field diff --git a/include/pgauge_monte.h b/include/pgauge_monte.h index 33decf4cf0..b9cfabdcd7 100644 --- a/include/pgauge_monte.h +++ b/include/pgauge_monte.h @@ -14,7 +14,7 @@ namespace quda { * @param[in] nhb number of heatbath steps * @param[in] nover number of overrelaxation steps */ - void Monte(GaugeField &data, RNG &rngstate, double Beta, int nhb, int nover); + void Monte(GaugeField &data, RNG &rngstate, real_t Beta, int nhb, int nover); /** * @brief Perform a cold start to the gauge field, identity SU(3) @@ -56,15 +56,15 @@ namespace quda { * @brief Calculate the Determinant * * @param[in] data Gauge field - * @returns double2 complex Determinant value + * @returns complex Determinant value */ - double2 getLinkDeterminant(GaugeField &data); + complex_t getLinkDeterminant(GaugeField &data); /** * @brief Calculate the Trace * * @param[in] data Gauge field - * @returns double2 complex trace value + * @returns complex trace value */ - double2 getLinkTrace(GaugeField &data); + complex_t getLinkTrace(GaugeField &data); } diff --git a/include/polynomial.h b/include/polynomial.h index aa51d372ed..4b4ea87cb8 100644 --- a/include/polynomial.h +++ b/include/polynomial.h @@ -9,68 +9,68 @@ namespace quda { - inline std::vector quadratic_formula(std::array coeff) + inline std::vector quadratic_formula(std::array coeff) { - std::vector z; + std::vector z; z.reserve(2); - double &a = coeff[0]; - double &b = coeff[1]; - double &c = coeff[2]; + real_t &a = coeff[0]; + real_t &b = coeff[1]; + real_t &c = coeff[2]; // a x^2 + b x + c = 0 if (a == 0) { // actually a linear equation if (b != 0) { z.push_back(-c / b); } } else { - double delta = b * b - 4.0 * a * c; + real_t delta = b * b - 4.0 * a * c; if (delta >= 0) { - z.push_back((-b + std::sqrt(delta)) / 
(2.0 * a));
- z.push_back((-b - std::sqrt(delta)) / (2.0 * a));
+ z.push_back((-b + sqrt(delta)) / (2.0 * a));
+ z.push_back((-b - sqrt(delta)) / (2.0 * a));
}
}
return z;
}
- inline std::vector cubic_formula(std::array coeff)
+ inline std::vector cubic_formula(std::array coeff)
{
- std::vector t;
+ std::vector t;
t.reserve(3);
// a x^3 + b x^2 + c x + d = 0
- double &a = coeff[0];
- double &b = coeff[1];
- double &c = coeff[2];
- double &d = coeff[3];
+ real_t &a = coeff[0];
+ real_t &b = coeff[1];
+ real_t &c = coeff[2];
+ real_t &d = coeff[3];
if (a == 0) {
// actually a quadratic equation.
- std::array quadratic_coeff = {coeff[1], coeff[2], coeff[3]};
+ std::array quadratic_coeff = {coeff[1], coeff[2], coeff[3]};
auto quad = quadratic_formula(quadratic_coeff);
for (size_t i = 0; i < quad.size(); i++) { t.push_back(quad[i]); }
return t;
}
- double a2 = a * a;
- double a3 = a * a * a;
+ real_t a2 = a * a;
+ real_t a3 = a * a * a;
- double b2 = b * b;
- double b3 = b * b * b;
+ real_t b2 = b * b;
+ real_t b3 = b * b * b;
- double p = (3.0 * a * c - b2) / (3.0 * a2);
- double q = (2.0 * b3 - 9.0 * a * b * c + 27.0 * a2 * d) / (27.0 * a3);
+ real_t p = (3.0 * a * c - b2) / (3.0 * a2);
+ real_t q = (2.0 * b3 - 9.0 * a * b * c + 27.0 * a2 * d) / (27.0 * a3);
// Now solving t^3 + p t + q = 0
if (p == 0) {
- t.push_back(std::cbrt(-q));
+ t.push_back(std::cbrt(-double(q)));
} else {
- double delta = -4.0 * p * p * p - 27.0 * q * q;
+ real_t delta = -4.0 * p * p * p - 27.0 * q * q;
if (delta == 0) {
@@ -80,18 +80,18 @@ namespace quda
} else if (delta > 0) {
- double theta = std::acos(1.5 * (q / p) * std::sqrt(-3.0 / p));
- double tmp = 2.0 * std::sqrt(-p / 3.0);
+ double theta = std::acos(double(1.5 * (q / p) * sqrt(-3.0 / p)));
+ real_t tmp = 2.0 * sqrt(-p / 3.0);
for (int k = 0; k < 3; k++) { t.push_back(tmp * std::cos((theta - 2.0 * M_PI * k) / 3.0)); }
} else if (delta < 0) {
if (p < 0) {
- double theta = std::acosh(-1.5 * std::abs(q) / p * std::sqrt(-3.0 / p));
- t.push_back(-2.0 * std::abs(q) / q * std::sqrt(-p / 3.0) * cosh(theta / 3.0));
+ double theta = std::acosh(-double(1.5 * abs(q) / p * sqrt(-3.0 / p)));
+ t.push_back(-2.0 * abs(q) / q * sqrt(-p / 3.0) * cosh(theta / 3.0));
} else if (p > 0) {
- double theta = std::asinh(+1.5 * q / p * std::sqrt(3.0 / p));
- t.push_back(-2.0 * std::sqrt(p / 3.0) * sinh(theta / 3.0));
+ double theta = std::asinh(double(1.5 * q / p * sqrt(3.0 / p)));
+ t.push_back(-2.0 * sqrt(p / 3.0) * sinh(theta / 3.0));
}
}
}
@@ -101,11 +101,11 @@ namespace quda
return t;
}
- inline double poly4(std::array coeffs, double x)
+ inline real_t poly4(std::array coeffs, real_t x)
{
- double x2 = x * x;
- double x3 = x * x2;
- double x4 = x2 * x2;
+ real_t x2 = x * x;
+ real_t x3 = x * x2;
+ real_t x4 = x2 * x2;
return x4 * coeffs[4] + x3 * coeffs[3] + x2 * coeffs[2] + x * coeffs[1] + coeffs[0];
}
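[Editor's example: with polynomial.h now generic over real_t, the quadratic branch above behaves as in this standalone sketch, which substitutes plain double for real_t and mirrors the logic of the patched code.]

#include <array>
#include <cmath>
#include <cstdio>
#include <vector>

// Standalone double-precision version of quadratic_formula, for illustration.
std::vector<double> quadratic_formula(std::array<double, 3> coeff)
{
  std::vector<double> z;
  double a = coeff[0], b = coeff[1], c = coeff[2]; // a x^2 + b x + c = 0
  if (a == 0) {
    if (b != 0) z.push_back(-c / b); // degenerate case: linear equation
  } else {
    double delta = b * b - 4.0 * a * c;
    if (delta >= 0) { // only real roots are returned
      z.push_back((-b + std::sqrt(delta)) / (2.0 * a));
      z.push_back((-b - std::sqrt(delta)) / (2.0 * a));
    }
  }
  return z;
}

int main()
{
  for (auto r : quadratic_formula({1.0, -3.0, 2.0})) printf("%g\n", r); // prints 2 then 1
}

diff --git a/include/quda_define.h.in b/include/quda_define.h.in
index 850c3990df..2dbe1604f7 100644
--- a/include/quda_define.h.in
+++ b/include/quda_define.h.in
@@ -260,3 +260,10 @@ static_assert(QUDA_ORDER_FP_MG == 2 || QUDA_ORDER_FP_MG == 4 || QUDA_ORDER_FP_MG
#if !defined(QUDA_TARGET_CUDA) && !defined(QUDA_TARGET_HIP) && !defined(QUDA_TARGET_SYCL)
#error "No QUDA_TARGET selected"
#endif
+
+#cmakedefine QUDA_SCALAR_TYPE @QUDA_SCALAR_TYPE@
+#cmakedefine QUDA_REDUCTION_TYPE @QUDA_REDUCTION_TYPE@
+
+#cmakedefine QUDA_REDUCTION_ALGORITHM_NAIVE @QUDA_REDUCTION_ALGORITHM_NAIVE@
+#cmakedefine QUDA_REDUCTION_ALGORITHM_KAHAN @QUDA_REDUCTION_ALGORITHM_KAHAN@
+#cmakedefine QUDA_REDUCTION_ALGORITHM_REPRODUCIBLE @QUDA_REDUCTION_ALGORITHM_REPRODUCIBLE@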
diff --git a/include/quda_internal.h b/include/quda_internal.h
index dd8a6c8177..7fe85f6390 100644
--- a/include/quda_internal.h
+++ b/include/quda_internal.h
@@ -50,10 +50,24 @@
#include
#include
#include "timer.h"
+#include "dbldbl.h"
namespace quda
{
- using Complex = std::complex<double>;
+ /**
+ Scalar real variable type on the host
+ */
+ using real_t = QUDA_SCALAR_TYPE;
+
+ /**
+ Scalar complex variable type on the host
+ */
+ using complex_t = std::complex<real_t>;
+
+ /**
+ Underlying type to use for reductions
+ */
+ using reduction_t = QUDA_REDUCTION_TYPE;
/**
Array object type used for storing lattice dimensions
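[Editor's example: the host scalar type is now injected by the build system, and the consumption pattern is roughly as below. The fallback #define exists only to make the sketch self-contained; in QUDA the macro comes from the configured quda_define.h.]

#include <complex>
#include <cstdio>

// Fallback so the sketch compiles standalone; the real definition is
// generated by CMake (e.g. double or float).
#ifndef QUDA_SCALAR_TYPE
#define QUDA_SCALAR_TYPE double
#endif

using real_t = QUDA_SCALAR_TYPE;
using complex_t = std::complex<real_t>;

int main()
{
  complex_t z(3.0, 4.0);
  printf("sizeof(real_t) = %zu, |z| = %g\n", sizeof(real_t), (double)std::abs(z)); // |z| = 5
}

diff --git a/include/quda_matrix.h b/include/quda_matrix.h
index 823ee78b89..99ba8b308d 100644
--- a/include/quda_matrix.h
+++ b/include/quda_matrix.h
@@ -425,7 +425,9 @@ namespace quda {
}
template< template class Mat, class T, int N>
- __device__ __host__ inline Mat operator+(const Mat & a, const Mat & b)
+ __device__ __host__ inline
+ std::enable_if_t, Matrix> || std::is_same_v, HMatrix>, Mat>
+ operator+(const Mat & a, const Mat & b)
{
Mat result;
#pragma unroll
@@ -435,7 +437,8 @@ namespace quda {
}
template< template class Mat, class T, int N>
- __device__ __host__ inline Mat operator+=(Mat & a, const Mat & b)
+ std::enable_if_t, Matrix> || std::is_same_v, HMatrix>, Mat>
+ __device__ __host__ inline operator+=(Mat & a, const Mat & b)
{
#pragma unroll
for (int i = 0; i < a.size(); i++) a.data[i] += b.data[i];
@@ -443,7 +446,8 @@ namespace quda {
}
template< template class Mat, class T, int N>
- __device__ __host__ inline Mat operator+=(Mat & a, const T & b)
+ std::enable_if_t, Matrix> || std::is_same_v, HMatrix>, Mat>
+ __device__ __host__ inline operator+=(Mat & a, const T & b)
{
#pragma unroll
for (int i = 0; i < a.rows(); i++) a(i, i) += b;
@@ -451,7 +455,8 @@ namespace quda {
}
template< template class Mat, class T, int N>
- __device__ __host__ inline Mat operator-=(Mat & a, const Mat & b)
+ std::enable_if_t, Matrix> || std::is_same_v, HMatrix>, Mat>
+ __device__ __host__ inline operator-=(Mat & a, const Mat & b)
{
#pragma unroll
for (int i = 0; i < a.size(); i++) a.data[i] -= b.data[i];
diff --git a/include/reducer.h b/include/reducer.h
index 3c2f37f364..b510d30b33 100644
--- a/include/reducer.h
+++ b/include/reducer.h
@@ -3,8 +3,8 @@
#include "complex_quda.h"
#include "quda_constants.h"
#include "quda_api.h"
-#include
-#include
+#include "math_helper.cuh"
+#include "float_vector.h"
#include "comm_quda.h"
#include "fast_intdiv.h"
@@ -23,12 +23,6 @@ namespace quda
{
-#ifdef QUAD_SUM
- using device_reduce_t = doubledouble;
-#else
- using device_reduce_t = double;
-#endif
-
namespace reducer {
/**
@@ -79,15 +73,77 @@ namespace quda
/**
plus reducer, used for conventional sum reductions
*/
- template struct plus {
+ template struct plus {
static constexpr bool do_sum = true;
using reduce_t = T;
using reducer_t = plus;
template static inline void comm_reduce(std::vector &a) { comm_allreduce_sum(a); }
- __device__ __host__ static inline T init() { return zero(); }
+ __device__ __host__ static inline T init() { return T{}; }
__device__ __host__ static inline T apply(T a, T b) { return a + b; }
__device__ __host__ inline T operator()(T a, T b) const { return apply(a, b); }
+
+ template
+ __device__ __host__ static inline std::enable_if_t>, T> apply(T a, const U &b)
+ {
+#pragma unroll
+ for (int i = 0; i < T::N; i++) a[i].operator+=(b[i]);
+ return a;
+ }
+ template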
+ __device__ __host__ inline std::enable_if_t>, T> operator()(T a, const U &b) const { return apply(a, b); }
+ };
+
+#ifdef QUDA_REDUCTION_ALGORITHM_REPRODUCIBLE
+ /**
+ plus reducer, specialized for reproducible sum reductions
+ */
+ template
+ struct plus>>> {
+ static constexpr bool do_sum = true;
+ using reduce_t = T;
+ using reducer_t = plus;
+ template static inline void comm_reduce(std::vector &a) { comm_allreduce_sum(a); }
+ __device__ __host__ static inline T init() { return reduce_t(); }
+ // rfa_t + rfa_t
+ // FIXME - should we use references here?
+ __device__ __host__ static inline T apply(T a, T b) { a.operator+=(b); return a; }
+ __device__ __host__ inline T operator()(T a, T b) const { return apply(a, b); }
+
+ __device__ __host__ static inline T apply(T a, typename T::ftype b) { a.operator+=(b); return a; }
+ __device__ __host__ inline T operator()(T a, typename T::ftype b) const { return apply(a, b); }
+ };
+
+ /**
+ plus reducer, specialized for arrays of reproducible sum reductions
+ */
+ template
+ struct plus, T::N>>>>
+ {
+ static constexpr bool do_sum = true;
+ using reduce_t = T;
+ using reducer_t = plus;
+ template static inline void comm_reduce(std::vector &a) { comm_allreduce_sum(a); }
+ __device__ __host__ static inline T init() { return reduce_t{}; }
+ // rfa_t + rfa_t
+ // FIXME - should we use references here?
+ __device__ __host__ static inline T apply(T a, T b)
+ {
+#pragma unroll
+ for (int i = 0; i < T::N; i++) a[i].operator+=(b[i]);
+ return a;
+ }
+ __device__ __host__ inline T operator()(T a, T b) const { return apply(a, b); }
+
+ __device__ __host__ static inline T apply(T a, const array &b)
+ {
+#pragma unroll
+ for (int i = 0; i < T::N; i++) a[i].operator+=(b[i]);
+ return a;
+ }
+
+ __device__ __host__ inline T operator()(T a, const array &b) const { return apply(a, b); }
};
+#endif
/**
maximum reducer, used for max reductions
diff --git a/include/register_traits.h b/include/register_traits.h
index ecb618b9f5..1e255d1413 100644
--- a/include/register_traits.h
+++ b/include/register_traits.h
@@ -180,7 +180,7 @@ namespace quda {
static const int value = 1;
};
- template <> struct vec_length {
+ template <> struct vec_length {
static const int value = 2;
};
template <> struct vec_length> {
@@ -319,6 +319,19 @@ namespace quda {
typedef char8 type;
};
+ // demote vector type to underlying scalar type
+ template struct get_scalar;
+ template <> struct get_scalar { using type = float; };
+ template <> struct get_scalar { using type = double; };
+ template <> struct get_scalar { using type = double; };
+ template <> struct get_scalar { using type = doubledouble; };
+ template <> struct get_scalar { using type = doubledouble; };
+ template <> struct get_scalar> { using type = float; };
+ template <> struct get_scalar> { using type = double; };
+ template struct get_scalar> { using type = typename get_scalar::type; };
+
+ template using get_scalar_t = typename get_scalar::type;
+
template struct AllocType { };
template<> struct AllocType { typedef size_t type; };
template<> struct AllocType { typedef int type; };
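[Editor's example: the get_scalar trait added above peels vector and complex wrappers down to the underlying scalar. A minimal re-creation of the idea, with simplified types that are not the QUDA definitions:]

#include <complex>
#include <type_traits>

// Minimal get_scalar-style trait: strip wrappers down to the scalar type.
template <typename T, int n> struct array { T data[n]; };

template <typename T> struct get_scalar { using type = T; };
template <typename T> struct get_scalar<std::complex<T>> { using type = T; };
template <typename T, int n> struct get_scalar<array<T, n>> { using type = typename get_scalar<T>::type; };

template <typename T> using get_scalar_t = typename get_scalar<T>::type;

static_assert(std::is_same_v<get_scalar_t<double>, double>);
static_assert(std::is_same_v<get_scalar_t<std::complex<float>>, float>);
static_assert(std::is_same_v<get_scalar_t<array<std::complex<double>, 4>>, double>); // recurses through the array

int main() { return 0; }

diff --git a/include/reliable_updates.h b/include/reliable_updates.h
index a1ddd9f122..7940e88bfb 100644
--- a/include/reliable_updates.h
+++ b/include/reliable_updates.h
@@ -34,44 +34,44 @@ namespace quda
bool alternative_reliable;
- double u;
- double uhigh;
- double Anorm;
- double delta;
+ real_t u;
+ real_t uhigh;
+ real_t Anorm;
+ real_t delta;
// this parameter determines how many consecutive reliable update
// residual increases we tolerate before terminating the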
solver, // i.e., how long do we want to keep trying to converge - int maxResIncrease; // check if we reached the limit of our tolerance - int maxResIncreaseTotal; + int maxResIncrease = 0; // check if we reached the limit of our tolerance + int maxResIncreaseTotal = 0; bool use_heavy_quark_res; - int hqmaxresIncrease; - int hqmaxresRestartTotal; // this limits the number of heavy quark restarts we can do + int hqmaxresIncrease = 0; + int hqmaxresRestartTotal = 0; // this limits the number of heavy quark restarts we can do }; struct ReliableUpdates { const ReliableUpdatesParams params; - const double deps; - static constexpr double dfac = 1.1; - double d_new = 0; - double d = 0; - double dinit = 0; - double xNorm = 0; - double xnorm = 0; - double pnorm = 0; - double ppnorm = 0; - double delta; - double beta = 0.0; - - double rNorm; - double r0Norm; - double maxrx; - double maxrr; - double maxr_deflate; // The maximum residual since the last deflation + const real_t deps; + static constexpr real_t dfac = 1.1; + real_t d_new = 0; + real_t d = 0; + real_t dinit = 0; + real_t xNorm = 0; + real_t xnorm = 0; + real_t pnorm = 0; + real_t ppnorm = 0; + real_t delta; + real_t beta = 0.0; + + real_t rNorm; + real_t r0Norm; + real_t maxrx; + real_t maxrr; + real_t maxr_deflate; // The maximum residual since the last deflation int resIncrease = 0; int resIncreaseTotal = 0; @@ -90,7 +90,7 @@ namespace quda @param params the parameters @param r2 the residual norm squared */ - ReliableUpdates(ReliableUpdatesParams params, double r2) : + ReliableUpdates(ReliableUpdatesParams params, real_t r2) : params(params), deps(sqrt(params.u)), delta(params.delta), @@ -110,23 +110,23 @@ namespace quda /** @brief Update the norm squared for p (thus ppnorm) */ - void update_ppnorm(double ppnorm_) { ppnorm = ppnorm_; } + void update_ppnorm(real_t ppnorm_) { ppnorm = ppnorm_; } /** @brief Update the norm for r (thus rNorm) */ - void update_rNorm(double rNorm_) { rNorm = rNorm_; } + void update_rNorm(real_t rNorm_) { rNorm = rNorm_; } /** @brief Update maxr_deflate */ - void update_maxr_deflate(double r2) { maxr_deflate = sqrt(r2); } + void update_maxr_deflate(real_t r2) { maxr_deflate = sqrt(r2); } /** @brief Evaluate whether a reliable update is needed @param r2_old the old residual norm squared */ - void evaluate(double r2_old) + void evaluate(real_t r2_old) { if (params.alternative_reliable) { // alternative reliable updates @@ -155,7 +155,7 @@ namespace quda @brief Accumulate the estimate for error - used when reliable update is not performed @param alpha the alpha that is used in CG to update the solution vector x, given p */ - void accumulate_norm(double alpha) + void accumulate_norm(real_t alpha) { // accumulate norms if (params.alternative_reliable) { @@ -164,7 +164,7 @@ namespace quda xnorm = sqrt(pnorm); d_new = d + params.u * rNorm + params.uhigh * params.Anorm * xnorm; if (steps_since_reliable == 0 && getVerbosity() >= QUDA_DEBUG_VERBOSE) - printfQuda("New dnew: %e (r %e , y %e)\n", d_new, params.u * rNorm, params.uhigh * params.Anorm * xnorm); + printfQuda("New dnew: %e (r %e , y %e)\n", double(d_new), double(params.u * rNorm), double(params.uhigh * params.Anorm * xnorm)); } steps_since_reliable++; } @@ -174,18 +174,18 @@ namespace quda @param r2 the residual norm squared @param y2 the solution vector norm squared */ - void update_norm(double r2, ColorSpinorField &y) + void update_norm(real_t r2, ColorSpinorField &y) { // update_norms if (params.alternative_reliable) { - double y2 = blas::norm2(y); + real_t 
y2 = blas::norm2(y); dinit = params.uhigh * (sqrt(r2) + params.Anorm * sqrt(y2)); d = d_new; xnorm = 0; // sqrt(norm2(x)); pnorm = 0; // pnorm + alpha * sqrt(norm2(p)); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) - printfQuda("New dinit: %e (r %e , y %e)\n", dinit, params.uhigh * sqrt(r2), - params.uhigh * params.Anorm * sqrt(y2)); + printfQuda("New dinit: %e (r %e , y %e)\n", double(dinit), double(params.uhigh * sqrt(r2)), + double(params.uhigh * params.Anorm * sqrt(y2))); d_new = dinit; } else { rNorm = sqrt(r2); @@ -201,20 +201,20 @@ namespace quda @param[in/out] L2breakdown whether or not L2 breakdown @param L2breakdown_eps L2 breakdown epsilon */ - bool reliable_break(double r2, double stop, bool &L2breakdown, double L2breakdown_eps) + bool reliable_break(real_t r2, real_t stop, bool &L2breakdown, real_t L2breakdown_eps) { // break-out check if we have reached the limit of the precision if (sqrt(r2) > r0Norm && updateX and not L2breakdown) { // reuse r0Norm for this resIncrease++; resIncreaseTotal++; warningQuda("new reliable residual norm %e is greater than previous reliable residual norm %e (total #inc %i)", - sqrt(r2), r0Norm, resIncreaseTotal); + double(sqrt(r2)), double(r0Norm), resIncreaseTotal); if ((params.use_heavy_quark_res and sqrt(r2) < L2breakdown_eps) or resIncrease > params.maxResIncrease or resIncreaseTotal > params.maxResIncreaseTotal or r2 < stop) { if (params.use_heavy_quark_res) { L2breakdown = true; - warningQuda("L2 breakdown %e, %e", sqrt(r2), L2breakdown_eps); + warningQuda("L2 breakdown %e, %e", double(sqrt(r2)), double(L2breakdown_eps)); } else { if (resIncrease > params.maxResIncrease or resIncreaseTotal > params.maxResIncreaseTotal or r2 < stop) { warningQuda("solver exiting due to too many true residual norm increases"); @@ -235,7 +235,7 @@ namespace quda @param heavy_quark_res_old the old heavy quark residual @param[in/out] heavy_quark_restart whether should restart the heavy quark */ - bool reliable_heavy_quark_break(bool L2breakdown, double heavy_quark_res, double heavy_quark_res_old, + bool reliable_heavy_quark_break(bool L2breakdown, real_t heavy_quark_res, real_t heavy_quark_res_old, bool &heavy_quark_restart) { if (params.use_heavy_quark_res and L2breakdown) { @@ -248,7 +248,7 @@ namespace quda if (heavy_quark_res > heavy_quark_res_old) { // check if new hq residual is greater than previous hqresIncrease++; // count the number of consecutive increases warningQuda("CG: new reliable HQ residual norm %e is greater than previous reliable residual norm %e", - heavy_quark_res, heavy_quark_res_old); + double(heavy_quark_res), double(heavy_quark_res_old)); // break out if we do not improve here anymore if (hqresIncrease > params.hqmaxresIncrease) { warningQuda("CG: solver exiting due to too many heavy quark residual norm increases (%i/%i)", hqresIncrease, @@ -272,7 +272,7 @@ namespace quda @brief Reset the counters - after a reliable update has been performed @param r2 residual norm squared */ - void reset(double r2) + void reset(real_t r2) { steps_since_reliable = 0; r0Norm = sqrt(r2); diff --git a/include/staggered_kd_build_xinv.h b/include/staggered_kd_build_xinv.h index 2bd1b4f600..dd5622b636 100644 --- a/include/staggered_kd_build_xinv.h +++ b/include/staggered_kd_build_xinv.h @@ -14,7 +14,7 @@ namespace quda @param mass [in] Mass of the original staggered operator w/out factor of 2 convention @param dagger_approximation[in] Whether or not to use the dagger approximation, using the dagger of X instead of Xinv */ - void 
BuildStaggeredKahlerDiracInverse(GaugeField &Xinv, const GaugeField &gauge, const double mass,
+ void BuildStaggeredKahlerDiracInverse(GaugeField &Xinv, const GaugeField &gauge, const real_t mass,
const bool dagger_approximation);
/**
@@ -25,7 +25,7 @@
@param mass [in] Mass of the original staggered operator w/out factor of 2 convention, needed for dagger approx
*/
void ReorderStaggeredKahlerDiracInverse(GaugeField &xInvFineLayout, const GaugeField &xInvCoarseLayout,
- const bool dagger_approximation, const double mass);
+ const bool dagger_approximation, const real_t mass);
/**
@brief Allocate and build the Kahler-Dirac inverse block for KD operators
@@ -34,7 +34,7 @@
@param dagger_approximation[in] Whether or not to use the dagger approximation, using the dagger of X instead of Xinv
@return constructed Xinv
*/
- std::shared_ptr AllocateAndBuildStaggeredKahlerDiracInverse(const GaugeField &gauge, const double mass,
+ std::shared_ptr AllocateAndBuildStaggeredKahlerDiracInverse(const GaugeField &gauge, const real_t mass,
const bool dagger_approximation);
} // namespace quda
diff --git a/include/targets/cuda/block_reduce_helper.h b/include/targets/cuda/block_reduce_helper.h
index dd368b6d02..bdf3988694 100644
--- a/include/targets/cuda/block_reduce_helper.h
+++ b/include/targets/cuda/block_reduce_helper.h
@@ -11,7 +11,8 @@
*/
// whether to use cooperative groups (or cub)
-#if !(_NVHPC_CUDA || (defined(__clang__) && defined(__CUDA__))) // neither nvc++ or clang-cuda yet support CG
+#if !(_NVHPC_CUDA || (defined(__clang__) && defined(__CUDA__))) && !defined(QUDA_REDUCTION_ALGORITHM_REPRODUCIBLE) // neither nvc++ nor clang-cuda yet supports CG
+// FIXME add CG support for reproducible reductions (can't naively break up the operation)
#define USE_CG
#endif
@@ -44,37 +45,32 @@ namespace quda
*/
template struct atomic_type;
- template <> struct atomic_type {
- using type = device_reduce_t;
+ template <> struct atomic_type {
+ using type = double; // must break up 128-bit types into doubles
};
- template <> struct atomic_type {
- using type = float;
- };
-
- template struct atomic_type>>> {
- using type = device_reduce_t;
+ template <> struct atomic_type {
+ using type = double;
};
- template
- struct atomic_type, T::N>>>> {
- using type = device_reduce_t;
+ template <> struct atomic_type {
+ using type = float;
};
- template struct atomic_type, T::N>>>> {
- using type = double;
+ template struct atomic_type>>> {
+ using type = typename atomic_type::type;
};
- template struct atomic_type, T::N>>>> {
- using type = float;
+ template struct atomic_type::value>> {
+ using type = typename T::ftype;
};
- template struct atomic_type>>> {
- using type = float;
+ template struct atomic_type>>> {
+ using type = typename atomic_type::type;
};
- template struct atomic_type>>> {
- using type = double;
+ template struct atomic_type>>> {
+ using type = typename atomic_type::type;
};
// pre-declaration of warp_reduce that we wish to specialize
diff --git a/include/targets/cuda/math_helper.cuh b/include/targets/cuda/math_helper.cuh
index bdc333297a..cee198e2d0 100644
--- a/include/targets/cuda/math_helper.cuh
+++ b/include/targets/cuda/math_helper.cuh
@@ -3,7 +3,11 @@
#include
#include
-#if (CUDA_VERSION >= 11070) && !defined(_NVHPC_CUDA)
+#if defined(__CUDACC__) || defined(_NVHPC_CUDA) || (defined(__clang__) && defined(__CUDA__))
+#define QUDA_CUDA_CC
+#endif
+
+#if (CUDA_VERSION >= 11070) && defined(QUDA_CUDA_CC) && !defined(_NVHPC_CUDA)
#define BUILTIN_ASSUME(x) \
bool p = x; \
__builtin_assume(p); @@ -34,6 +38,7 @@ namespace quda { template inline void operator()(const T& a, T *s, T *c) { ::sincos(a, s, c); } }; +#ifdef QUDA_CUDA_CC template <> struct sincos_impl { template __device__ inline void operator()(const T& a, T *s, T *c) { @@ -41,6 +46,7 @@ namespace quda { sincos(a, s, c); } }; +#endif /** * @brief Combined sin and cos calculation in QUDA NAMESPACE @@ -55,9 +61,11 @@ namespace quda { inline void operator()(const float& a, float *s, float *c) { ::sincosf(a, s, c); } }; +#ifdef QUDA_CUDA_CC template <> struct sincosf_impl { __device__ inline void operator()(const float& a, float *s, float *c) { __sincosf(a, s, c); } }; +#endif /** * @brief Combined sin and cos calculation in QUDA NAMESPACE @@ -75,10 +83,11 @@ namespace quda { template inline void operator()(const T& a, T *s, T *c) { ::sincos(a * static_cast(M_PI), s, c); } }; +#ifdef QUDA_CUDA_CC template <> struct sincospi_impl { template __device__ inline void operator()(const T& a, T *s, T *c) { sincospi(a, s, c); } }; - +#endif /** * @brief Combined sinpi and cospi calculation in QUDA NAMESPACE @@ -100,20 +109,25 @@ namespace quda { template<> inline __host__ __device__ void sincospi(const float& a, float *s, float *c) { quda::sincos(a * static_cast(M_PI), s, c); } + template struct sinpi_impl { + template inline T operator()(T a) { return ::sin(a * static_cast(M_PI)); } + }; +#ifdef QUDA_CUDA_CC + template <> struct sinpi_impl { template __device__ inline T operator()(T a) { return ::sinpi(a); } }; +#endif /** * @brief Sine pi calculation in QUDA NAMESPACE * @param a the angle * @return result of the sin(a * pi) */ - template inline __host__ __device__ T sinpi(T a) { return ::sinpi(a); } + template inline __host__ __device__ T sinpi(T a) { return target::dispatch(a); } + -#ifndef _NVHPC_CUDA - template struct sinpif_impl { inline float operator()(float a) { return ::sinpif(a); } }; -#else template struct sinpif_impl { inline float operator()(float a) { return ::sinf(a * static_cast(M_PI)); } }; -#endif +#ifdef QUDA_CUDA_CC template <> struct sinpif_impl { __device__ inline float operator()(float a) { return __sinf(a * static_cast(M_PI)); } }; +#endif /** * @brief Sine pi calculation in QUDA NAMESPACE. @@ -124,20 +138,27 @@ namespace quda { */ template<> inline __host__ __device__ float sinpi(float a) { return target::dispatch(a); } + template struct cospi_impl { + template inline T operator()(T a) { return ::cos(a * static_cast(M_PI)); } + }; +#ifdef QUDA_CUDA_CC + template <> struct cospi_impl { + template __device__ inline T operator()(T a) { return ::cospi(a); } + }; +#endif /** * @brief Cosine pi calculation in QUDA NAMESPACE * @param a the angle * @return result of the cos(a * pi) */ - template inline __host__ __device__ T cospi(T a) { return ::cospi(a); } + template inline __host__ __device__ T cospi(T a) { return target::dispatch(a); } + -#ifndef _NVHPC_CUDA - template struct cospif_impl { inline float operator()(float a) { return ::cospif(a); } }; -#else template struct cospif_impl { inline float operator()(float a) { return ::cosf(a * static_cast(M_PI)); } }; -#endif +#ifdef QUDA_CUDA_CC template <> struct cospif_impl { __device__ inline float operator()(float a) { return __cosf(a * static_cast(M_PI)); } }; +#endif /** * @brief Cosine pi calculation in QUDA NAMESPACE. 
@@ -153,9 +174,11 @@ namespace quda {
template inline T operator()(T a) { return static_cast(1.0) / sqrt(a); }
};
+#ifdef QUDA_CUDA_CC
template <> struct rsqrt_impl {
template __device__ inline T operator()(T a) { return ::rsqrt(a); }
};
+#endif
/**
* @brief Reciprocal square root function (rsqrt)
@@ -169,6 +192,7 @@ namespace quda {
template struct fpow_impl {
template inline real operator()(real a, int b) { return std::pow(a, b); }
};
+#ifdef QUDA_CUDA_CC
template <> struct fpow_impl {
__device__ inline double operator()(double a, int b) { return ::pow(a, b); }
@@ -179,6 +203,7 @@ namespace quda {
return b & 1 ? sign * power : power;
}
};
+#endif
/*
@brief Fast power function that works for negative "a" argument
@@ -189,11 +214,60 @@ namespace quda {
template __device__ __host__ inline real fpow(real a, int b) { return target::dispatch(a, b); }
template struct fdividef_impl { inline float operator()(float a, float b) { return a / b; } };
+#ifdef QUDA_CUDA_CC
template <> struct fdividef_impl { __device__ inline float operator()(float a, float b) { return __fdividef(a, b); } };
+#endif
/**
@brief Optimized division routine on the device
*/
__device__ __host__ inline float fdividef(float a, float b) { return target::dispatch(a, b); }
+
+ template struct dmul_rn_impl {
+ inline double operator()(double a, double b) { return a * b; }
+ };
+#ifdef QUDA_CUDA_CC
+ template <> struct dmul_rn_impl {
+ __device__ inline double operator()(double a, double b) { return ::__dmul_rn(a, b); }
+ };
+#endif
+
+ /**
+ @brief IEEE double precision multiplication
+ */
+ __device__ __host__ inline double dmul_rn(double a, double b) { return target::dispatch(a, b); }
+
+
+ template struct dadd_rn_impl {
+ inline double operator()(double a, double b) { return a + b; }
+ };
+#ifdef QUDA_CUDA_CC
+ template <> struct dadd_rn_impl {
+ __device__ inline double operator()(double a, double b) { return ::__dadd_rn(a, b); }
+ };
+#endif
+
+ /**
+ @brief IEEE double precision addition
+ */
+ __device__ __host__ inline double dadd_rn(double a, double b) { return target::dispatch(a, b); }
+
+
+ template struct fma_rn_impl {
+ inline double operator()(double a, double b, double c) { return std::fma(a, b, c); }
+ };
+#ifdef QUDA_CUDA_CC
+ template <> struct fma_rn_impl {
+ __device__ inline double operator()(double a, double b, double c) { return ::__fma_rn(a, b, c); }
+ };
+#endif
+
+ /**
+ @brief IEEE double precision fused multiply add
+ */
+ __device__ __host__ inline double fma_rn(double a, double b, double c) { return target::dispatch(a, b, c); }
+ }
+
+#undef QUDA_CUDA_CC
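[Editor's example: these host/device pairs all use the same dispatch idiom, a functor templated on an is_device flag, with the device specialization calling the CUDA intrinsic. A host-compilable sketch of the pattern; the simplified target::dispatch here is illustrative only.]

#include <cmath>
#include <cstdio>

// Generic (host) functor; a true __device__ specialization would call the
// CUDA intrinsic ::__fma_rn when compiled for the device.
template <bool is_device> struct fma_rn_impl {
  double operator()(double a, double b, double c) { return std::fma(a, b, c); }
};

namespace target {
  // Pick the device or host instantiation depending on the compilation pass.
  template <template <bool> class F, typename... Args> auto dispatch(Args &&...args)
  {
#ifdef __CUDA_ARCH__
    return F<true>()(static_cast<Args &&>(args)...);  // device-compiled path
#else
    return F<false>()(static_cast<Args &&>(args)...); // host path
#endif
  }
}

double fma_rn(double a, double b, double c) { return target::dispatch<fma_rn_impl>(a, b, c); }

int main() { printf("%g\n", fma_rn(2.0, 3.0, 1.0)); } // prints 7

diff --git a/include/targets/cuda/reduce_helper.h b/include/targets/cuda/reduce_helper.h
index 73fc0bfbc0..794999cdd3 100644
--- a/include/targets/cuda/reduce_helper.h
+++ b/include/targets/cuda/reduce_helper.h
@@ -85,6 +85,14 @@ namespace quda
result_h = static_cast(reducer::get_host_buffer());
count = reducer::get_count();
+ if constexpr (is_rfa>::value) {
+ static bool init = false;
+ if (!init) {
+ cudaMemcpyToSymbol(reproducible::bin_device_buffer, static_cast(&reducer::get_rfa_bins()),
+ sizeof(reproducible::RFA_bins), 0, cudaMemcpyHostToDevice);
+ }
+ }
+
if (!commAsyncReduction()) {
// initialize the result buffer so we can test for completion
for (int i = 0; i < n_reduce * n_item; i++) {
@@ -120,33 +128,30 @@ namespace quda
@param[out] result The reduction result is copied here
@param[in] stream The stream on which the reduction is being done
*/
- template
- void complete(std::vector &result, const qudaStream_t = device::get_default_stream())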
+ auto complete(const qudaStream_t = device::get_default_stream())
{
- if (launch_error == QUDA_ERROR) return; // kernel launch failed so return
- if (launch_error == QUDA_ERROR_UNINITIALIZED) errorQuda("No reduction kernel appears to have been launched");
+ std::vector result(n_reduce);
if (consumed) errorQuda("Cannot call complete more than once for each construction");
+ if (launch_error == QUDA_ERROR_UNINITIALIZED) errorQuda("No reduction kernel appears to have been launched");
+ if (launch_error != QUDA_ERROR) {
+ for (int i = 0; i < n_reduce * n_item; i++) {
+ while (result_h[i].load(cuda::std::memory_order_relaxed) == init_value()) { }
+ }
- for (int i = 0; i < n_reduce * n_item; i++) {
- while (result_h[i].load(cuda::std::memory_order_relaxed) == init_value()) { }
- }
-
- // copy back result element by element and convert if necessary to host reduce type
- // unit size here may differ from system_atomic_t size, e.g., if doing double-double
- const int n_element = n_reduce * sizeof(T) / sizeof(device_t);
- if (result.size() != (unsigned)n_element)
- errorQuda("result vector length %lu does not match n_reduce %d", result.size(), n_element);
- for (int i = 0; i < n_element; i++) result[i] = reinterpret_cast(result_h)[i];
+ memcpy(reinterpret_cast(result.data()), reinterpret_cast(result_h), n_reduce * sizeof(T));
- if (!reset) {
- consumed = true;
- } else {
- // reset the atomic counter - this allows multiple calls to complete with ReduceArg construction
- for (int i = 0; i < n_reduce * n_item; i++) {
- result_h[i].store(init_value(), cuda::std::memory_order_relaxed);
+ if (!reset) {
+ consumed = true;
+ } else {
+ // reset the atomic counter - this allows multiple calls to complete with ReduceArg construction
+ for (int i = 0; i < n_reduce * n_item; i++) {
+ result_h[i].store(init_value(), cuda::std::memory_order_relaxed);
+ }
+ cuda::std::atomic_thread_fence(cuda::std::memory_order_release);
}
- cuda::std::atomic_thread_fence(cuda::std::memory_order_release);
}
+
+ return result;
}
};
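[Editor's example: complete() now allocates and returns the result vector itself, spinning on a sentinel until the device has published every element before copying the buffer out. A host-only sketch of that completion pattern, with std::atomic standing in for cuda::std::atomic and a thread standing in for the kernel; all names are illustrative.]

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
  constexpr int n = 4;
  constexpr double init_value = -1e300; // sentinel meaning "not yet written"
  std::atomic<double> result_h[n];
  for (auto &r : result_h) r.store(init_value, std::memory_order_relaxed);

  std::thread producer([&] { // stands in for the reduction kernel's final write
    for (int i = 0; i < n; i++) result_h[i].store(i + 0.5, std::memory_order_relaxed);
  });

  // spin until every slot has left its sentinel value, then copy out
  std::vector<double> result(n);
  for (int i = 0; i < n; i++)
    while (result_h[i].load(std::memory_order_relaxed) == init_value) { }
  for (int i = 0; i < n; i++) result[i] = result_h[i].load(std::memory_order_relaxed);
  producer.join();
  printf("%g %g %g %g\n", result[0], result[1], result[2], result[3]);
}

@@ -174,21 +179,32 @@ namespace quda
constexpr bool coalesced_write = true;
#endif
if constexpr (coalesced_write) {
- static_assert(n <= device::warp_size(), "reduction array is greater than warp size");
- auto mask = __ballot_sync(0xffffffff, tid < n);
- if (tid < n) {
+ if (tid < device::warp_size()) { // only first warp takes part in write
+
atomic_t sum_tmp[n];
memcpy(sum_tmp, &sum, sizeof(sum));
- atomic_t s = sum_tmp[0];
+ constexpr auto m = (n + device::warp_size() - 1) / device::warp_size();
#pragma unroll
- for (int i = 1; i < n; i++) {
- auto si = __shfl_sync(mask, sum_tmp[i], 0);
- if (i == tid) s = si;
- }
+ for (auto j = 0; j < m; j++) {
+
+ auto t = j * device::warp_size() + tid; // effective thread index
+ atomic_t s = sum_tmp[j * device::warp_size()];
+ auto mask = __ballot_sync(0xffffffff, t < n);
- s = (s == init_value()) ? terminate_value() : s;
- arg.result_d[n * idx + tid].store(s, cuda::std::memory_order_relaxed);
+ if (t < n) {
+#pragma unroll
+ for (auto i = 1; i < device::warp_size(); i++) {
+ if (j * device::warp_size() + i < n) {
+ auto si = __shfl_sync(mask, sum_tmp[j * device::warp_size() + i], 0);
+ if (i == tid) s = si; // j * device::warp_size() cancels out on both sides
+ }
+ }
+
+ s = (s == init_value()) ?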
terminate_value() : s; + arg.result_d[n * idx + t].store(s, cuda::std::memory_order_relaxed); + } + } } } else { // write out the final reduced value diff --git a/include/targets/cuda/reduction_kernel.h b/include/targets/cuda/reduction_kernel.h index b2d23e7897..274675b009 100644 --- a/include/targets/cuda/reduction_kernel.h +++ b/include/targets/cuda/reduction_kernel.h @@ -23,13 +23,12 @@ namespace quda template