From a5d3463c215abb0fe2167364490b6aceda4d69b4 Mon Sep 17 00:00:00 2001 From: James Osborn Date: Wed, 4 Oct 2023 15:49:06 -0500 Subject: [PATCH] make special ops initialization mandatory at construction time --- include/dslash_helper.cuh | 13 +- include/kernels/block_orthogonalize.cuh | 4 +- include/kernels/block_transpose.cuh | 4 +- include/kernels/clover_deriv.cuh | 3 +- include/kernels/coarse_op_kernel.cuh | 8 +- include/kernels/color_spinor_pack.cuh | 10 +- include/kernels/covDev.cuh | 3 +- include/kernels/dslash_clover_helper.cuh | 12 +- include/kernels/dslash_coarse.cuh | 4 +- include/kernels/dslash_domain_wall_4d.cuh | 3 +- .../dslash_domain_wall_4d_fused_m5.cuh | 4 +- include/kernels/dslash_domain_wall_5d.cuh | 3 +- include/kernels/dslash_domain_wall_m5.cuh | 8 +- include/kernels/dslash_mobius_eofa.cuh | 18 ++- .../kernels/dslash_ndeg_twisted_clover.cuh | 8 +- ...ash_ndeg_twisted_clover_preconditioned.cuh | 4 +- include/kernels/dslash_ndeg_twisted_mass.cuh | 3 +- ...slash_ndeg_twisted_mass_preconditioned.cuh | 4 +- include/kernels/dslash_staggered.cuh | 3 +- .../dslash_twisted_clover_preconditioned.cuh | 3 +- include/kernels/dslash_twisted_mass.cuh | 3 +- .../dslash_twisted_mass_preconditioned.cuh | 3 +- include/kernels/dslash_wilson.cuh | 3 +- include/kernels/dslash_wilson_clover.cuh | 3 +- .../dslash_wilson_clover_hasenbusch_twist.cuh | 3 +- ...clover_hasenbusch_twist_preconditioned.cuh | 3 +- .../dslash_wilson_clover_preconditioned.cuh | 3 +- include/kernels/field_strength_tensor.cuh | 3 +- include/kernels/gauge_ape.cuh | 6 +- include/kernels/gauge_fix_ovr.cuh | 7 +- include/kernels/gauge_force.cuh | 6 +- include/kernels/gauge_loop_trace.cuh | 7 +- include/kernels/gauge_stout.cuh | 7 +- include/kernels/gauge_wilson_flow.cuh | 4 +- include/kernels/hisq_paths_force.cuh | 8 +- include/kernels/laplace.cuh | 3 +- include/kernels/madwf_transfer.cuh | 4 +- include/kernels/restrictor.cuh | 4 +- include/targets/generic/helpers.h | 6 + include/targets/generic/special_ops.h | 33 ++--- include/targets/sycl/block_reduce_helper.h | 29 +++-- include/targets/sycl/block_reduction_kernel.h | 10 ++ include/targets/sycl/kernel.h | 20 +++ include/targets/sycl/reduce_helper.h | 8 +- include/targets/sycl/reduction_kernel.h | 14 ++ include/targets/sycl/shared_memory_helper.h | 2 + include/targets/sycl/special_ops_target.h | 120 +++++++++++------- include/targets/sycl/tunable_kernel.h | 10 ++ 48 files changed, 321 insertions(+), 133 deletions(-) diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index 8a6261699d..465d687979 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -662,15 +662,18 @@ namespace quda static constexpr bool dagger = Arg::dagger; static constexpr KernelType kernel_type = Arg::kernel_type; static constexpr const char *filename() { return Arg::D::filename(); } - constexpr dslash_functor(const Arg &arg) : arg(arg.arg) { } + using typename getSpecialOps::KernelOpsT; + template + constexpr dslash_functor(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg.arg) { } template __forceinline__ __device__ void operator()(int, int s, int parity, bool active = true) { - typename Arg::D dslash(arg); - if constexpr (hasSpecialOps) { - dslash.setSpecialOps(*this); - } + //typename Arg::D dslash(arg); + //if constexpr (hasSpecialOps) { + //dslash.setSpecialOps(*this); + //} + typename Arg::D dslash(*this); // for full fields set parity from z thread index else use arg setting if (nParity == 1) parity = arg.parity; diff --git a/include/kernels/block_orthogonalize.cuh b/include/kernels/block_orthogonalize.cuh index 69842d3084..4218bf0950 100644 --- a/include/kernels/block_orthogonalize.cuh +++ b/include/kernels/block_orthogonalize.cuh @@ -114,7 +114,9 @@ namespace quda { using dot_t = typename BlockOrtho_Params::dot_t; using real = typename Arg::real; - constexpr BlockOrtho_(const Arg &arg) : arg(arg) {} + using typename BlockOrtho_Params::Ops::KernelOpsT; + template + constexpr BlockOrtho_(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void load(ColorSpinor &v, int parity, int x_cb, int chirality, int i) diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index 42fe9f7080..e227e1f1b1 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -57,7 +57,9 @@ namespace quda template struct BlockTransposeKernel : BlockTransposeKernelOps::Ops { const Arg &arg; - constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { } + using typename BlockTransposeKernelOps::Ops::KernelOpsT; + template + constexpr BlockTransposeKernel(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } /** diff --git a/include/kernels/clover_deriv.cuh b/include/kernels/clover_deriv.cuh index 3623a9a40f..38ab548766 100644 --- a/include/kernels/clover_deriv.cuh +++ b/include/kernels/clover_deriv.cuh @@ -205,7 +205,8 @@ namespace quda template struct CloverDerivative : computeForceOps { const Arg &arg; - constexpr CloverDerivative(const Arg &arg) : arg(arg) {} + template + constexpr CloverDerivative(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __host__ __device__ void operator()(int x_cb, int parity, int mu) diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index 07e0dd77e7..a329fbb3dd 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -1702,7 +1702,9 @@ namespace quda { static constexpr int nFace = 1; const Arg &arg; static constexpr const char *filename() { return KERNEL_FILE; } - constexpr compute_vuv(const Arg &arg) : arg(arg) { } + using typename storeCoarseSharedAtomic_impl::Ops::KernelOpsT; + template + constexpr compute_vuv(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) { } /** 3-d parallelism @@ -1735,7 +1737,9 @@ namespace quda { static constexpr int nFace = 3; const Arg &arg; static constexpr const char *filename() { return KERNEL_FILE; } - constexpr compute_vlv(const Arg &arg) : arg(arg) { } + using typename storeCoarseSharedAtomic_impl::Ops::KernelOpsT; + template + constexpr compute_vlv(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) { } /** 3-d parallelism diff --git a/include/kernels/color_spinor_pack.cuh b/include/kernels/color_spinor_pack.cuh index ad5de2426d..d675408c88 100644 --- a/include/kernels/color_spinor_pack.cuh +++ b/include/kernels/color_spinor_pack.cuh @@ -292,11 +292,15 @@ namespace quda { } } - template struct GhostPacker : - std::conditional_t::Ops, NoSpecialOps> { + template using GhostPackerOps = + std::conditional_t::Ops, NoSpecialOps>; + + template struct GhostPacker : GhostPackerOps { using Arg = Arg_; const Arg &arg; - constexpr GhostPacker(const Arg &arg) : arg(arg) {} + using typename GhostPackerOps::KernelOpsT; + template + constexpr GhostPacker(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } template diff --git a/include/kernels/covDev.cuh b/include/kernels/covDev.cuh index 0d3a4d328a..46405548e6 100644 --- a/include/kernels/covDev.cuh +++ b/include/kernels/covDev.cuh @@ -124,7 +124,8 @@ namespace quda dslash_default, NoSpecialOps { const Arg &arg; - constexpr covDev(const Arg &arg) : arg(arg) {} + //constexpr covDev(const Arg &arg) : arg(arg) {} + template constexpr covDev(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_clover_helper.cuh b/include/kernels/dslash_clover_helper.cuh index 779a1833e4..a21beead5e 100644 --- a/include/kernels/dslash_clover_helper.cuh +++ b/include/kernels/dslash_clover_helper.cuh @@ -172,19 +172,23 @@ namespace quda { arg.out(x_cb, spinor_parity) = out; } }; - + + template using NdegTwistCloverApplyOps = + SpecialOps>>; + // if (!inverse) apply (Clover + i*a*gamma_5*tau_3 + b*epsilon*tau_1) to the input spinor // else apply (Clover + i*a*gamma_5*tau_3 + b*epsilon*tau_1)/(Clover^2 + a^2 - b^2) to the input spinor // noting that appropriate signs are carried by a and b depending on inverse - template struct NdegTwistCloverApply : - SpecialOps>> { + template struct NdegTwistCloverApply : NdegTwistCloverApplyOps { static constexpr int N = Arg::nColor * Arg::nSpin / 2; using real = typename Arg::real; using fermion = ColorSpinor; using half_fermion = ColorSpinor; using Mat = HMatrix; const Arg &arg; - constexpr NdegTwistCloverApply(const Arg &arg) : arg(arg) {} + using typename NdegTwistCloverApplyOps::KernelOpsT; + template + constexpr NdegTwistCloverApply(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char* filename() { return KERNEL_FILE; } template diff --git a/include/kernels/dslash_coarse.cuh b/include/kernels/dslash_coarse.cuh index 715020dc32..2ab59f09b4 100644 --- a/include/kernels/dslash_coarse.cuh +++ b/include/kernels/dslash_coarse.cuh @@ -338,7 +338,9 @@ namespace quda { template struct CoarseDslash : CoarseDslashParams::Ops { using Arg = Arg_; const Arg &arg; - constexpr CoarseDslash(const Arg &arg) : arg(arg) {} + using typename CoarseDslashParams::Ops::KernelOpsT; + template + constexpr CoarseDslash(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } template diff --git a/include/kernels/dslash_domain_wall_4d.cuh b/include/kernels/dslash_domain_wall_4d.cuh index 9a122d5fd8..dda8088430 100644 --- a/include/kernels/dslash_domain_wall_4d.cuh +++ b/include/kernels/dslash_domain_wall_4d.cuh @@ -28,7 +28,8 @@ namespace quda struct domainWall4D : dslash_default, NoSpecialOps { const Arg &arg; - constexpr domainWall4D(const Arg &arg) : arg(arg) {} + //constexpr domainWall4D(const Arg &arg) : arg(arg) {} + template constexpr domainWall4D(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh index cffb1f8031..5280b7c98e 100644 --- a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh +++ b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh @@ -61,7 +61,9 @@ namespace quda static constexpr Dslash5Type dslash5_type = Arg::type; const Arg &arg; - constexpr domainWall4DFusedM5(const Arg &arg) : arg(arg) { } + using typename d5Params::Ops::KernelOpsT; + //constexpr domainWall4DFusedM5(const Arg &arg) : arg(arg) { } + template constexpr domainWall4DFusedM5(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_domain_wall_5d.cuh b/include/kernels/dslash_domain_wall_5d.cuh index 80038ede52..da75217c1a 100644 --- a/include/kernels/dslash_domain_wall_5d.cuh +++ b/include/kernels/dslash_domain_wall_5d.cuh @@ -26,7 +26,8 @@ namespace quda struct domainWall5D : dslash_default, NoSpecialOps { const Arg &arg; - constexpr domainWall5D(const Arg &arg) : arg(arg) {} + //constexpr domainWall5D(const Arg &arg) : arg(arg) {} + template constexpr domainWall5D(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation static constexpr QudaPCType pc_type() { return QUDA_5D_PC; } diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index 3a4f536291..f84408f8ca 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -333,7 +333,9 @@ namespace quda template struct dslash5 : d5Params::Ops { using Arg = Arg_; const Arg &arg; - constexpr dslash5(const Arg &arg) : arg(arg) { } + using typename d5Params::Ops::KernelOpsT; + template + constexpr dslash5(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } /** @@ -589,7 +591,9 @@ namespace quda template struct dslash5inv : dslash5invParams::Ops { using Arg = Arg_; const Arg &arg; - constexpr dslash5inv(const Arg &arg) : arg(arg) {} + using typename dslash5invParams::Ops::KernelOpsT; + template + constexpr dslash5inv(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } /** diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index 98dc2abff1..195a36989d 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -90,6 +90,8 @@ namespace quda } }; + template using eofa_dslash5Ops = + SpecialOps>>; /** @brief Apply the D5 operator at given site @param[in] arg Argument struct containing any meta data and accessors @@ -97,10 +99,11 @@ namespace quda @param[in] x_cb Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - template struct eofa_dslash5 : - SpecialOps>> { + template struct eofa_dslash5 : eofa_dslash5Ops { const Arg &arg; - constexpr eofa_dslash5(const Arg &arg) : arg(arg) {} + using typename eofa_dslash5Ops::KernelOpsT; + template + constexpr eofa_dslash5(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } template @@ -170,6 +173,8 @@ namespace quda } }; + template using eofa_dslash5invOps = + SpecialOps>>; /** @brief Apply the M5 inverse operator at a given site on the lattice. This is the original algorithm as described in Kim and @@ -182,10 +187,11 @@ namespace quda @param[in] x_cb Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - template struct eofa_dslash5inv : - SpecialOps>> { + template struct eofa_dslash5inv : eofa_dslash5invOps { const Arg &arg; - constexpr eofa_dslash5inv(const Arg &arg) : arg(arg) {} + using typename eofa_dslash5invOps::KernelOpsT; + template + constexpr eofa_dslash5inv(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } template diff --git a/include/kernels/dslash_ndeg_twisted_clover.cuh b/include/kernels/dslash_ndeg_twisted_clover.cuh index 8dba5e0a53..d2a2e47c4e 100644 --- a/include/kernels/dslash_ndeg_twisted_clover.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover.cuh @@ -40,16 +40,18 @@ namespace quda using real = typename mapper::type; using Vec = ColorSpinor; using Cache = SharedMemoryCache; - using Ops = SpecialOps; + //using Ops = SpecialOps; //template - //using Ops = conditional_t,NoSpecialOps>; + using Ops = std::conditional_t,NoSpecialOps>; }; template struct nDegTwistedClover : dslash_default, nDegTwistedCloverParams::Ops { const Arg &arg; - constexpr nDegTwistedClover(const Arg &arg) : arg(arg) {} + using typename nDegTwistedCloverParams::Ops::KernelOpsT; + //constexpr nDegTwistedClover(const Arg &arg) : arg(arg) {} + template constexpr nDegTwistedClover(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh index b87014b6bb..66b8382b1b 100644 --- a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh @@ -52,7 +52,9 @@ namespace quda struct nDegTwistedCloverPreconditioned : dslash_default, nDegTwistedCloverPreconditionedParams::Ops { const Arg &arg; - constexpr nDegTwistedCloverPreconditioned(const Arg &arg) : arg(arg) {} + using typename nDegTwistedCloverPreconditionedParams::Ops::KernelOpsT; + //constexpr nDegTwistedCloverPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr nDegTwistedCloverPreconditioned(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_ndeg_twisted_mass.cuh b/include/kernels/dslash_ndeg_twisted_mass.cuh index 6ed31be3ac..032bc40e8e 100644 --- a/include/kernels/dslash_ndeg_twisted_mass.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass.cuh @@ -26,7 +26,8 @@ namespace quda struct nDegTwistedMass : dslash_default, NoSpecialOps { const Arg &arg; - constexpr nDegTwistedMass(const Arg &arg) : arg(arg) {} + //constexpr nDegTwistedMass(const Arg &arg) : arg(arg) {} + template constexpr nDegTwistedMass(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index 4907b83d14..dae62d171f 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -48,7 +48,9 @@ namespace quda struct nDegTwistedMassPreconditioned : dslash_default, nDegTwistedMassPreconditionedParams::Ops { const Arg &arg; - constexpr nDegTwistedMassPreconditioned(const Arg &arg) : arg(arg) {} + using typename nDegTwistedMassPreconditionedParams::Ops::KernelOpsT; + //constexpr nDegTwistedMassPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr nDegTwistedMassPreconditioned(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) {} constexpr int twist_pack() const { return (!Arg::asymmetric && dagger) ? 2 : 0; } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index 20842735c3..327f672890 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -173,7 +173,8 @@ namespace quda struct staggered : dslash_default, NoSpecialOps { const Arg &arg; - constexpr staggered(const Arg &arg) : arg(arg) {} + //constexpr staggered(const Arg &arg) : arg(arg) {} + template constexpr staggered(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_twisted_clover_preconditioned.cuh b/include/kernels/dslash_twisted_clover_preconditioned.cuh index c1dfbc144f..c9a7985192 100644 --- a/include/kernels/dslash_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_twisted_clover_preconditioned.cuh @@ -40,7 +40,8 @@ namespace quda struct twistedCloverPreconditioned : dslash_default, NoSpecialOps { const Arg &arg; - constexpr twistedCloverPreconditioned(const Arg &arg) : arg(arg) {} + //constexpr twistedCloverPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr twistedCloverPreconditioned(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_twisted_mass.cuh b/include/kernels/dslash_twisted_mass.cuh index c8276d01b8..da6d9f442b 100644 --- a/include/kernels/dslash_twisted_mass.cuh +++ b/include/kernels/dslash_twisted_mass.cuh @@ -24,7 +24,8 @@ namespace quda struct twistedMass : dslash_default, NoSpecialOps { const Arg &arg; - constexpr twistedMass(const Arg &arg) : arg(arg) {} + //constexpr twistedMass(const Arg &arg) : arg(arg) {} + template constexpr twistedMass(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index d6e66a7635..56557c0a13 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -135,7 +135,8 @@ namespace quda struct twistedMassPreconditioned : dslash_default, NoSpecialOps { const Arg &arg; - constexpr twistedMassPreconditioned(const Arg &arg) : arg(arg) {} + //constexpr twistedMassPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr twistedMassPreconditioned(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation constexpr int twist_pack() const { return (!Arg::asymmetric && dagger) ? 1 : 0; } diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 64c55bc587..8bc8062e6d 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -137,7 +137,8 @@ namespace quda dslash_default, NoSpecialOps { const Arg &arg; - constexpr wilson(const Arg &arg) : arg(arg) {} + //constexpr wilson(const Arg &arg) : arg(arg) {} + template constexpr wilson(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation // out(x) = M*in = (-D + m) * in(x-mu) diff --git a/include/kernels/dslash_wilson_clover.cuh b/include/kernels/dslash_wilson_clover.cuh index 22d541dad1..46d09b881d 100644 --- a/include/kernels/dslash_wilson_clover.cuh +++ b/include/kernels/dslash_wilson_clover.cuh @@ -36,7 +36,8 @@ namespace quda struct wilsonClover : dslash_default, NoSpecialOps { const Arg &arg; - constexpr wilsonClover(const Arg &arg) : arg(arg) {} + //constexpr wilsonClover(const Arg &arg) : arg(arg) {} + template constexpr wilsonClover(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh b/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh index 7b92016567..3381e92196 100644 --- a/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh +++ b/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh @@ -36,7 +36,8 @@ namespace quda struct cloverHasenbusch : dslash_default, NoSpecialOps { const Arg &arg; - constexpr cloverHasenbusch(const Arg &arg) : arg(arg) {} + //constexpr cloverHasenbusch(const Arg &arg) : arg(arg) {} + template constexpr cloverHasenbusch(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh b/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh index 261edfc0d1..61bb116aae 100644 --- a/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh +++ b/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh @@ -38,7 +38,8 @@ namespace quda struct cloverHasenbuschPreconditioned : dslash_default, NoSpecialOps { const Arg &arg; - constexpr cloverHasenbuschPreconditioned(const Arg &arg) : arg(arg) {} + //constexpr cloverHasenbuschPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr cloverHasenbuschPreconditioned(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_wilson_clover_preconditioned.cuh b/include/kernels/dslash_wilson_clover_preconditioned.cuh index 029af93027..dfe6b39dc6 100644 --- a/include/kernels/dslash_wilson_clover_preconditioned.cuh +++ b/include/kernels/dslash_wilson_clover_preconditioned.cuh @@ -34,7 +34,8 @@ namespace quda struct wilsonCloverPreconditioned : dslash_default, NoSpecialOps { const Arg &arg; - constexpr wilsonCloverPreconditioned(const Arg &arg) : arg(arg) {} + //constexpr wilsonCloverPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr wilsonCloverPreconditioned(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/field_strength_tensor.cuh b/include/kernels/field_strength_tensor.cuh index 0caf5c3a7e..b1ebb453c8 100644 --- a/include/kernels/field_strength_tensor.cuh +++ b/include/kernels/field_strength_tensor.cuh @@ -179,7 +179,8 @@ namespace quda template struct ComputeFmunu : computeFmunuCoreOps { using Arg = Arg_; const Arg &arg; - constexpr ComputeFmunu(const Arg &arg) : arg(arg) {} + template + constexpr ComputeFmunu(const Arg &arg, const Ops &...ops) : computeFmunuCoreOps(ops...), arg(arg) {} static constexpr const char* filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int parity, int mu_nu) diff --git a/include/kernels/gauge_ape.cuh b/include/kernels/gauge_ape.cuh index e0a7180ba1..693733cec4 100644 --- a/include/kernels/gauge_ape.cuh +++ b/include/kernels/gauge_ape.cuh @@ -41,10 +41,12 @@ namespace quda } } }; - + template struct APE : computeStapleOps { const Arg &arg; - constexpr APE(const Arg &arg) : arg(arg) {} + //constexpr APE(const Arg &arg) : arg(arg) {} + template + constexpr APE(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char* filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int parity, int dir) diff --git a/include/kernels/gauge_fix_ovr.cuh b/include/kernels/gauge_fix_ovr.cuh index b8a40b95a8..832b0fda28 100644 --- a/include/kernels/gauge_fix_ovr.cuh +++ b/include/kernels/gauge_fix_ovr.cuh @@ -143,7 +143,9 @@ namespace quda { //template struct computeFix : SpecialOps> { template struct computeFix : computeFixOps { const Arg &arg; - constexpr computeFix(const Arg &arg) : arg(arg) {} + using typename computeFixOps::KernelOpsT; + template + constexpr computeFix(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } template @@ -162,8 +164,7 @@ namespace quda { for (int dr = 0; dr < 4; dr++) p += arg.border[dr]; getCoords(x, idx, arg.X, p + parity); } else { - if (!allthreads || active) - idx = arg.borderpoints[parity][idx]; // load the lattice site assigment + if (!allthreads || active) idx = arg.borderpoints[parity][idx]; // load the lattice site assigment x[3] = idx / (X[0] * X[1] * X[2]); x[2] = (idx / (X[0] * X[1])) % X[2]; x[1] = (idx / X[0]) % X[1]; diff --git a/include/kernels/gauge_force.cuh b/include/kernels/gauge_force.cuh index 2bf8e3d0ca..996c3bdc5b 100644 --- a/include/kernels/gauge_force.cuh +++ b/include/kernels/gauge_force.cuh @@ -46,9 +46,11 @@ namespace quda { template struct GaugeForce : SpecialOps> { + using KOps = SpecialOps>; const Arg &arg; - constexpr GaugeForce(const Arg &arg) : arg(arg) {} - static constexpr const char *filename() { return KERNEL_FILE; } + template + constexpr GaugeForce(const Arg &arg, const Ops &...ops) : KOps(ops...), arg(arg) {} + static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ void operator()(int x_cb, int parity, int dir) { diff --git a/include/kernels/gauge_loop_trace.cuh b/include/kernels/gauge_loop_trace.cuh index 5ef20f8090..15648b3fd2 100644 --- a/include/kernels/gauge_loop_trace.cuh +++ b/include/kernels/gauge_loop_trace.cuh @@ -52,13 +52,16 @@ namespace quda { } }; - template struct GaugeLoop : plus, SpecialOps> + template + struct GaugeLoop : plus, KernelOps> { using reduce_t = typename Arg::reduce_t; using plus::operator(); static constexpr int reduce_block_dim = 2; // x_cb and parity are mapped to x const Arg &arg; - constexpr GaugeLoop(const Arg &arg) : arg(arg) {} + //constexpr GaugeLoop(const Arg &arg) : arg(arg) {} + template + constexpr GaugeLoop(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline reduce_t operator()(reduce_t &value, int x_cb, int parity, int path_id) diff --git a/include/kernels/gauge_stout.cuh b/include/kernels/gauge_stout.cuh index 8164edf19f..650576987d 100644 --- a/include/kernels/gauge_stout.cuh +++ b/include/kernels/gauge_stout.cuh @@ -53,7 +53,8 @@ namespace quda using Link = Matrix, Arg::nColor>; const Arg &arg; - constexpr STOUT(const Arg &arg) : arg(arg) {} + template + constexpr STOUT(const Arg &arg, const OpsArgs &...ops) : computeStapleOps(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int parity, int dir) @@ -127,9 +128,11 @@ namespace quda using real = typename Arg::Float; using Complex = complex; using Link = Matrix, Arg::nColor>; + using typename OvrImpSTOUTOps::Ops::KernelOpsT; const Arg &arg; - constexpr OvrImpSTOUT(const Arg &arg) : arg(arg) {} + template + constexpr OvrImpSTOUT(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int parity, int dir) diff --git a/include/kernels/gauge_wilson_flow.cuh b/include/kernels/gauge_wilson_flow.cuh index 82dd6ca3b0..43e25f117f 100644 --- a/include/kernels/gauge_wilson_flow.cuh +++ b/include/kernels/gauge_wilson_flow.cuh @@ -162,9 +162,11 @@ namespace quda //template struct WFlow template struct WFlow : computeStapleOpsWF::Ops { + using typename computeStapleOpsWF::Ops::KernelOpsT; using Arg = Arg_; const Arg &arg; - constexpr WFlow(const Arg &arg) : arg(arg) {} + template + constexpr WFlow(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int parity, int dir) diff --git a/include/kernels/hisq_paths_force.cuh b/include/kernels/hisq_paths_force.cuh index 46a7988180..b114eee9ec 100644 --- a/include/kernels/hisq_paths_force.cuh +++ b/include/kernels/hisq_paths_force.cuh @@ -323,7 +323,9 @@ namespace quda { static_assert(Param::nu_next_positive == -1, "nu_next_positive should be set to -1 for AllThreeAllLepageLink"); static constexpr int compute_lepage = Param::compute_lepage; - constexpr AllThreeAllLepageLink(const Param ¶m) : arg(param.arg) {} + using typename AllThreeAllLepageLinkOps::Ops::KernelOpsT; + template + constexpr AllThreeAllLepageLink(const Param ¶m, const OpsArgs &...ops) : KernelOpsT(ops...), arg(param.arg) {} constexpr static const char *filename() { return KERNEL_FILE; } /** @@ -691,7 +693,9 @@ namespace quda { static constexpr int nu_next_positive = Param::nu_next_positive; // if nu_next_positive == -1, skip static_assert(Param::compute_lepage == -1, "compute_lepage should be set to -1 for AllFiveAllSevenLink"); - constexpr AllFiveAllSevenLink(const Param ¶m) : arg(param.arg) {} + using typename AllFiveAllSevenLinkOps::Ops::KernelOpsT; + template + constexpr AllFiveAllSevenLink(const Param ¶m, const OpsArgs &...ops) : KernelOpsT(ops...), arg(param.arg) {} constexpr static const char *filename() { return KERNEL_FILE; } /** diff --git a/include/kernels/laplace.cuh b/include/kernels/laplace.cuh index f1421c451c..b563854424 100644 --- a/include/kernels/laplace.cuh +++ b/include/kernels/laplace.cuh @@ -139,7 +139,8 @@ namespace quda dslash_default, NoSpecialOps { const Arg &arg; - constexpr laplace(const Arg &arg) : arg(arg) {} + //constexpr laplace(const Arg &arg) : arg(arg) {} + template constexpr laplace(const Ftor &ftor) : arg(ftor.arg) {} static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/madwf_transfer.cuh b/include/kernels/madwf_transfer.cuh index 65427f1088..cfb8563d90 100644 --- a/include/kernels/madwf_transfer.cuh +++ b/include/kernels/madwf_transfer.cuh @@ -102,7 +102,9 @@ namespace quda template struct Transfer5D : Transfer5DParams::Ops { const Arg &arg; - constexpr Transfer5D(const Arg &arg) : arg(arg) { } + using typename Transfer5DParams::Ops::KernelOpsT; + template + constexpr Transfer5D(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } /** diff --git a/include/kernels/restrictor.cuh b/include/kernels/restrictor.cuh index 59810f174c..115ea5ed99 100644 --- a/include/kernels/restrictor.cuh +++ b/include/kernels/restrictor.cuh @@ -115,7 +115,9 @@ namespace quda { using vector = typename RestrictorParams::vector; using BlockReduce_t = typename RestrictorParams::BlockReduce_t; const Arg &arg; - constexpr Restrictor(const Arg &arg) : arg(arg) {} + using typename SpecialOps::KernelOpsT; + template + constexpr Restrictor(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } template diff --git a/include/targets/generic/helpers.h b/include/targets/generic/helpers.h index eeb6a65533..ce13676e6d 100644 --- a/include/targets/generic/helpers.h +++ b/include/targets/generic/helpers.h @@ -24,6 +24,12 @@ namespace quda } }; + struct SizeZ { + static constexpr unsigned int size(dim3 block) { + return block.z; + } + }; + template struct SizeDims { static constexpr unsigned int size(dim3 block) { dim3 dims = D::dims(block); diff --git a/include/targets/generic/special_ops.h b/include/targets/generic/special_ops.h index ee1906bc74..1584a7dd0c 100644 --- a/include/targets/generic/special_ops.h +++ b/include/targets/generic/special_ops.h @@ -37,16 +37,29 @@ namespace quda { // alternative to SpecialOps struct NoSpecialOps { using SpecialOpsT = NoSpecialOps; + using KernelOpsT = NoSpecialOps; }; // SpecialOps forward declaration and base type template struct SpecialOps; + template using KernelOps = SpecialOps; template struct SpecialOps_Base { using SpecialOpsT = SpecialOps; + using KernelOpsT = SpecialOps; }; //template struct SpecialOps : SpecialOpsTarget { // using SpecialOpsT = SpecialOps; //}; + // getSpecialOps + template struct getSpecialOpsS { using type = NoSpecialOps; }; + template struct getSpecialOpsS> { + using type = typename T::SpecialOpsT; + }; + template struct getSpecialOpsS,void> { + using type = SpecialOps; + }; + template using getSpecialOps = typename getSpecialOpsS::type; + // hasSpecialOp: checks if first type matches any of the op // > template static constexpr bool hasSpecialOp = false; @@ -54,14 +67,14 @@ namespace quda { static constexpr bool hasSpecialOp> = ( std::is_same_v || ... ); //template void checkSpecialOps() { static_assert(hasSpecialOp); } - template void checkSpecialOps(const Ops &) { - static_assert(hasSpecialOp); + //template void checkSpecialOps(const Ops &) { + //static_assert(hasSpecialOp); + //} + template void checkSpecialOps(const Ops &) { + static_assert((hasSpecialOp || ...)); } - - - // OLD template struct op_Concurrent {}; // set of op types used concurrently (needs separate resources) @@ -94,16 +107,6 @@ namespace quda { template using only_SharedMemStatic = only_SharedMemory>; template using only_Concurrent = SpecialOps>; - // getSpecialOps - template struct getSpecialOpsS { using type = NoSpecialOps; }; - template struct getSpecialOpsS> { - using type = typename T::SpecialOpsT; - }; - template struct getSpecialOpsS,void> { - using type = SpecialOps; - }; - template using getSpecialOps = typename getSpecialOpsS::type; - // explicitSpecialOps template struct explicitSpecialOpsS : std::false_type {}; template diff --git a/include/targets/sycl/block_reduce_helper.h b/include/targets/sycl/block_reduce_helper.h index fbed3045ae..a91d8ea9b4 100644 --- a/include/targets/sycl/block_reduce_helper.h +++ b/include/targets/sycl/block_reduce_helper.h @@ -4,6 +4,7 @@ #include #include #include +#include /** @file block_reduce_helper.h @@ -132,21 +133,26 @@ namespace quda */ #define DYNAMIC_SLM template - struct block_reduceW { + //struct block_reduceW { + struct block_reduceW : SharedMemory { + using Smem = SharedMemory; + //using Smem::shared_mem_size; #ifdef DYNAMIC_SLM using opSmem = op_SharedMemory; + //using opSmem = SharedMemory; using dependencies = op_Sequential; using dependentOps = SpecialOps; - template - static constexpr size_t shared_mem_size(dim3 block, Arg &...arg) { - return opSizeBlockDivWarp::size(block, arg...); - } + //template + //static constexpr size_t shared_mem_size(dim3 block, Arg &...arg) { + //return opSizeBlockDivWarp::size(block, arg...); + //} #else #endif using BlockReduce_t = BlockReduce; - dependentOps ops; + //dependentOps ops; template - inline block_reduceW(S &ops) : ops(getDependentOps(ops)) {}; + //inline block_reduceW(S &ops) : ops(getDependentOps(ops)) {}; + inline block_reduceW(S &ops) : Smem(ops) {}; template struct warp_reduce_param { static constexpr int width = width_; @@ -183,7 +189,8 @@ namespace quda //__shared__ T storage[max_items]; #ifdef DYNAMIC_SLM - auto storage = getSharedMemPtr(ops); + //auto storage = getSharedMemPtr(ops); + auto storage = Smem::sharedMem(); #else static_assert(sizeof(T[max_items])<=device::shared_memory_size(), "Block reduce shared mem size too large"); auto mem = sycl::ext::oneapi::group_local_memory_for_overwrite(getGroup()); @@ -192,7 +199,8 @@ namespace quda // if first thread in warp, write result to shared memory if (thread_idx % device::warp_size() == 0) storage[batch * warp_items + warp_idx] = value; - blockSync(ops); + //blockSync(ops); + __syncthreads(); // whether to use the first warp or first thread for the final reduction constexpr bool final_warp_reduction = true; @@ -216,7 +224,8 @@ namespace quda if (all) { if (thread_idx == 0) storage[batch * warp_items + 0] = value; - blockSync(ops); + //blockSync(ops); + __syncthreads(); value = storage[batch * warp_items + 0]; } diff --git a/include/targets/sycl/block_reduction_kernel.h b/include/targets/sycl/block_reduction_kernel.h index bd4c5cd0dd..bec9c85905 100644 --- a/include/targets/sycl/block_reduction_kernel.h +++ b/include/targets/sycl/block_reduction_kernel.h @@ -95,6 +95,7 @@ namespace quda const unsigned int k = globalIdZ; if (k >= arg.threads.z) return; +#if 0 Functor f(arg); if constexpr (hasSpecialOps>) { f.setNdItem(ndi); @@ -102,6 +103,10 @@ namespace quda if constexpr (needsSharedMem>) { f.setSharedMem(smem...); } +#else + Ftor> f(arg, ndi, smem...); +#endif + f(block_idx, thread_idx); } template