diff --git a/include/targets/cuda/block_reduce_helper.h b/include/targets/cuda/block_reduce_helper.h index 59907822e8..6e38ed7696 100644 --- a/include/targets/cuda/block_reduce_helper.h +++ b/include/targets/cuda/block_reduce_helper.h @@ -302,7 +302,7 @@ namespace quda template struct block_reduce { - template HostDevice inline block_reduce(Ops &) {}; + template HostDevice inline block_reduce(const Ops &) {}; template static constexpr size_t shared_mem_size(dim3 block) { return SizeBlockDivWarp::size(block) * sizeof(T); } diff --git a/include/targets/cuda/kernel_ops_target.h b/include/targets/cuda/kernel_ops_target.h deleted file mode 100644 index ce567513bc..0000000000 --- a/include/targets/cuda/kernel_ops_target.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once -#include - -namespace quda { - - // KernelOps - template - struct KernelOps : KernelOps_Base { - //template constexpr void setKernelOps(const KernelOps &) { - // static_assert(std::is_same_v,KernelOps>); - //} - }; - - // op implementations - struct op_blockSync { - template - static constexpr unsigned int shared_mem_size(dim3, Arg &...) { return 0; } - }; - - template - struct op_warp_combine { - template - static constexpr unsigned int shared_mem_size(dim3, Arg &...) { return 0; } - }; - -} diff --git a/include/targets/generic/kernel_ops.h b/include/targets/generic/kernel_ops.h index 62f72f02d6..6c83f5fb6e 100644 --- a/include/targets/generic/kernel_ops.h +++ b/include/targets/generic/kernel_ops.h @@ -21,17 +21,16 @@ namespace quda @brief Used to declare an object of size equal to the size of the block Z dimension. */ struct SizeZ { - static constexpr unsigned int size(dim3 block) { - return block.z; - } + static constexpr unsigned int size(dim3 block) { return block.z; } }; /** @brief Used to declare an object of size equal to the block size divided by the warp size. */ struct SizeBlockDivWarp { - static constexpr unsigned int size(dim3 b) { - return (b.x * b.y * b.z + device::warp_size() - 1)/device::warp_size(); + static constexpr unsigned int size(dim3 b) + { + return (b.x * b.y * b.z + device::warp_size() - 1) / device::warp_size(); } }; @@ -69,78 +68,111 @@ namespace quda }; // forward declare sharedMemSize - template static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg); + template static constexpr unsigned int sharedMemSize(dim3 block, const Arg &...arg); - // alternative to KernelOps - struct NoKernelOps { - using KernelOpsT = NoKernelOps; - }; - // KernelOps forward declaration and base type + /** + @brief KernelOps forward declaration and KernelOps_Base type, + which the target specific KernelOps should inherit from. Kernels + can inherit from KernelOps to tag kernels with operations that + may need special resources like shared memory, or have other + special requirements (e.g. using block sync which may require + special handling for some targets). The template arguments, + T..., specify the types of the operations the kernel uses. + */ template struct KernelOps; template struct KernelOps_Base { using KernelOpsT = KernelOps; - template static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg) { + template static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg) + { return sharedMemSize(block, arg...); } }; /** - @brief Used to get KernelOps from a kernel type. Checks for the existence of KernelOpsT. + @brief Type that specifies a kernel does not have any operations + that need tagging. This can be used as an alternative in cases + where the operations are only conditionally used. + */ + struct NoKernelOps { + using KernelOpsT = NoKernelOps; + }; + + /** + @brief Used to get KernelOps from a kernel type. Checks for the + existence of KernelOpsT. */ - template struct getKernelOpsS { using type = NoKernelOps; }; - template struct getKernelOpsS> { + template struct getKernelOpsS { + using type = NoKernelOps; + }; + template struct getKernelOpsS> { using type = typename T::KernelOpsT; }; template using getKernelOps = typename getKernelOpsS::type; - // hasKernelOp: checks if first type matches any of the op - // > + /** + @brief Checks whether a kernel type is tagged with any KernelOps. + */ + template static constexpr bool hasKernelOps = !std::is_same_v, NoKernelOps>; + + /** + @brief Checks if first template type matches any of the ops in + the second template type, which is a KernelOps template type. + */ template static constexpr bool hasKernelOp = false; template - static constexpr bool hasKernelOp> = ( std::is_same_v || ... ); + static constexpr bool hasKernelOp> = (std::is_same_v || ...); - // checkKernelOps - template static constexpr void checkKernelOps(const Ops &) { - static_assert((hasKernelOp || ...)); + /** + @brief Function to statically check if the passed kernel functor was tagged with all the + operations in the template parameters. + */ + template static constexpr void checkKernelOps(const Ops &) + { + static_assert((hasKernelOp || ...)); } - // hasKernelOps - template inline constexpr bool hasKernelOps = !std::is_same_v,NoKernelOps>; - - // combineOps - template struct combineOpsS {}; - template struct combineOpsS> { - using type = KernelOps; }; - template struct combineOpsS,NoKernelOps> { - using type = KernelOps; }; - template struct combineOpsS,KernelOps> { - using type = KernelOps; }; + /** + @brief Helper to combine two KernelOps or NoKernelOps types. + */ + template struct combineOpsS { + }; + template struct combineOpsS> { + using type = KernelOps; + }; + template struct combineOpsS, NoKernelOps> { + using type = KernelOps; + }; + template struct combineOpsS, KernelOps> { + using type = KernelOps; + }; template using combineOps = typename combineOpsS::type; - // sharedMemSize + /** + @brief Gets the total shared memory size needed for a set of + kernel operations. If any ops types have an offset for the + shared memory, then the offset is included in the size. + */ template struct sharedMemSizeS { - template - static constexpr unsigned int size(dim3 block, Arg &...arg) { + template static constexpr unsigned int size(dim3 block, const Arg &...arg) + { return T::shared_mem_size(block, arg...); } }; template <> struct sharedMemSizeS { - template - static constexpr unsigned int size(dim3, Arg &...) { - return 0; - } + template static constexpr unsigned int size(dim3, const Arg &...) { return 0; } }; - template struct sharedMemSizeS> { - template - static constexpr unsigned int size(dim3 block, Arg &...arg) { + template struct sharedMemSizeS> { + template static constexpr unsigned int size(dim3 block, const Arg &...arg) + { return std::max({sharedMemSizeS::size(block, arg...)...}); } }; - template static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg) { + template static constexpr unsigned int sharedMemSize(dim3 block, const Arg &...arg) + { return sharedMemSizeS::size(block, arg...); } - // forward declarations of op types + // forward declarations of op types to be defined by target struct op_blockSync; template struct op_warp_combine; @@ -148,15 +180,4 @@ namespace quda using only_blockSync = KernelOps; template using only_warp_combine = KernelOps>; - // explicitKernelOps - //template struct explicitKernelOpsS : std::false_type {}; - //template - //struct explicitKernelOpsS> : std::true_type {}; - //template inline constexpr bool explicitKernelOps = explicitKernelOpsS::value; - - // checkKernelOp - //template static constexpr void checkKernelOp() { - // static_assert((std::is_same_v || ...) == true); - //} - } // namespace quda diff --git a/include/targets/hip/kernel_ops_target.h b/include/targets/hip/kernel_ops_target.h deleted file mode 100644 index b9308abda3..0000000000 --- a/include/targets/hip/kernel_ops_target.h +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once -#include - -namespace quda { - - // KernelOps - template - struct KernelOps : KernelOps_Base { - template constexpr void setKernelOps(const KernelOps &) { - static_assert(std::is_same_v,KernelOps>); - } - }; - - // op implementations - struct op_blockSync { - template - static constexpr unsigned int shared_mem_size(dim3, Arg &...) { return 0; } - }; - - template - struct op_warp_combine { - template - static constexpr unsigned int shared_mem_size(dim3, Arg &...) { return 0; } - }; - -}