Skip to content

Commit

Permalink
cleanup kernel ops
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Jan 8, 2024
1 parent ae7072d commit fd1393e
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 108 deletions.
2 changes: 1 addition & 1 deletion include/targets/cuda/block_reduce_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ namespace quda

template <typename T, int block_dim, int batch_size>
struct block_reduce {
template <typename Ops> HostDevice inline block_reduce(Ops &) {};
template <typename Ops> HostDevice inline block_reduce(const Ops &) {};
template <typename ...Arg> static constexpr size_t shared_mem_size(dim3 block) {
return SizeBlockDivWarp::size(block) * sizeof(T);
}
Expand Down
26 changes: 0 additions & 26 deletions include/targets/cuda/kernel_ops_target.h

This file was deleted.

131 changes: 76 additions & 55 deletions include/targets/generic/kernel_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,16 @@ namespace quda
@brief Used to declare an object of size equal to the size of the block Z dimension.
*/
struct SizeZ {
static constexpr unsigned int size(dim3 block) {
return block.z;
}
static constexpr unsigned int size(dim3 block) { return block.z; }
};

/**
@brief Used to declare an object of size equal to the block size divided by the warp size.
*/
struct SizeBlockDivWarp {
static constexpr unsigned int size(dim3 b) {
return (b.x * b.y * b.z + device::warp_size() - 1)/device::warp_size();
static constexpr unsigned int size(dim3 b)
{
return (b.x * b.y * b.z + device::warp_size() - 1) / device::warp_size();
}
};

Expand Down Expand Up @@ -69,94 +68,116 @@ namespace quda
};

// forward declare sharedMemSize
template <typename T, typename... Arg> static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg);
template <typename T, typename... Arg> static constexpr unsigned int sharedMemSize(dim3 block, const Arg &...arg);

// alternative to KernelOps
struct NoKernelOps {
using KernelOpsT = NoKernelOps;
};
// KernelOps forward declaration and base type
/**
@brief KernelOps forward declaration and KernelOps_Base type,
which the target specific KernelOps should inherit from. Kernels
can inherit from KernelOps to tag kernels with operations that
may need special resources like shared memory, or have other
special requirements (e.g. using block sync which may require
special handling for some targets). The template arguments,
T..., specify the types of the operations the kernel uses.
*/
template <typename... T> struct KernelOps;
template <typename... T> struct KernelOps_Base {
using KernelOpsT = KernelOps<T...>;
template <typename... Arg> static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg) {
template <typename... Arg> static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg)
{
return sharedMemSize<KernelOpsT>(block, arg...);
}
};

/**
@brief Used to get KernelOps from a kernel type. Checks for the existence of KernelOpsT.
@brief Type that specifies a kernel does not have any operations
that need tagging. This can be used as an alternative in cases
where the operations are only conditionally used.
*/
struct NoKernelOps {
using KernelOpsT = NoKernelOps;
};

/**
@brief Used to get KernelOps from a kernel type. Checks for the
existence of KernelOpsT.
*/
template <typename T, typename U = void> struct getKernelOpsS { using type = NoKernelOps; };
template <typename T> struct getKernelOpsS<T,std::conditional_t<true,void,typename T::KernelOpsT>> {
template <typename T, typename U = void> struct getKernelOpsS {
using type = NoKernelOps;
};
template <typename T> struct getKernelOpsS<T, std::conditional_t<true, void, typename T::KernelOpsT>> {
using type = typename T::KernelOpsT;
};
template <typename T> using getKernelOps = typename getKernelOpsS<T>::type;

// hasKernelOp: checks if first type matches any of the op
// <op, KernelOps<ops...>>
/**
@brief Checks whether a kernel type is tagged with any KernelOps.
*/
template <typename T> static constexpr bool hasKernelOps = !std::is_same_v<getKernelOps<T>, NoKernelOps>;

/**
@brief Checks if first template type matches any of the ops in
the second template type, which is a KernelOps template type.
*/
template <typename T, typename U> static constexpr bool hasKernelOp = false;
template <typename T, typename... U>
static constexpr bool hasKernelOp<T,KernelOps<U...>> = ( std::is_same_v<T,U> || ... );
static constexpr bool hasKernelOp<T, KernelOps<U...>> = (std::is_same_v<T, U> || ...);

// checkKernelOps
template <typename... T, typename Ops> static constexpr void checkKernelOps(const Ops &) {
static_assert((hasKernelOp<T,typename Ops::KernelOpsT> || ...));
/**
@brief Function to statically check if the passed kernel functor was tagged with all the
operations in the template parameters.
*/
template <typename... T, typename Ops> static constexpr void checkKernelOps(const Ops &)
{
static_assert((hasKernelOp<T, typename Ops::KernelOpsT> || ...));
}

// hasKernelOps
template <typename T> inline constexpr bool hasKernelOps = !std::is_same_v<getKernelOps<T>,NoKernelOps>;

// combineOps
template <typename... T> struct combineOpsS {};
template <typename... T> struct combineOpsS<NoKernelOps,KernelOps<T...>> {
using type = KernelOps<T...>; };
template <typename ... T> struct combineOpsS<KernelOps<T...>,NoKernelOps> {
using type = KernelOps<T...>; };
template <typename ...T, typename ...U> struct combineOpsS<KernelOps<T...>,KernelOps<U...>> {
using type = KernelOps<T..., U...>; };
/**
@brief Helper to combine two KernelOps or NoKernelOps types.
*/
template <typename... T> struct combineOpsS {
};
template <typename... T> struct combineOpsS<NoKernelOps, KernelOps<T...>> {
using type = KernelOps<T...>;
};
template <typename... T> struct combineOpsS<KernelOps<T...>, NoKernelOps> {
using type = KernelOps<T...>;
};
template <typename... T, typename... U> struct combineOpsS<KernelOps<T...>, KernelOps<U...>> {
using type = KernelOps<T..., U...>;
};
template <typename T, typename U> using combineOps = typename combineOpsS<T, U>::type;

// sharedMemSize
/**
@brief Gets the total shared memory size needed for a set of
kernel operations. If any ops types have an offset for the
shared memory, then the offset is included in the size.
*/
template <typename T> struct sharedMemSizeS {
template <typename ...Arg>
static constexpr unsigned int size(dim3 block, Arg &...arg) {
template <typename... Arg> static constexpr unsigned int size(dim3 block, const Arg &...arg)
{
return T::shared_mem_size(block, arg...);
}
};
template <> struct sharedMemSizeS<NoKernelOps> {
template <typename ...Arg>
static constexpr unsigned int size(dim3, Arg &...) {
return 0;
}
template <typename... Arg> static constexpr unsigned int size(dim3, const Arg &...) { return 0; }
};
template <typename ...T> struct sharedMemSizeS<KernelOps<T...>> {
template <typename ...Arg>
static constexpr unsigned int size(dim3 block, Arg &...arg) {
template <typename... T> struct sharedMemSizeS<KernelOps<T...>> {
template <typename... Arg> static constexpr unsigned int size(dim3 block, const Arg &...arg)
{
return std::max({sharedMemSizeS<T>::size(block, arg...)...});
}
};
template <typename T, typename... Arg> static constexpr unsigned int sharedMemSize(dim3 block, Arg &...arg) {
template <typename T, typename... Arg> static constexpr unsigned int sharedMemSize(dim3 block, const Arg &...arg)
{
return sharedMemSizeS<T>::size(block, arg...);
}

// forward declarations of op types
// forward declarations of op types to be defined by target
struct op_blockSync;
template <typename T> struct op_warp_combine;

// only types for convenience
using only_blockSync = KernelOps<op_blockSync>;
template <typename T> using only_warp_combine = KernelOps<op_warp_combine<T>>;

// explicitKernelOps
//template <typename T, typename U = void> struct explicitKernelOpsS : std::false_type {};
//template <typename T>
//struct explicitKernelOpsS<T,std::conditional_t<true,void,typename T::KernelOpsT>> : std::true_type {};
//template <typename T> inline constexpr bool explicitKernelOps = explicitKernelOpsS<T>::value;

// checkKernelOp
//template <typename T, typename... U> static constexpr void checkKernelOp() {
// static_assert((std::is_same_v<T,U> || ...) == true);
//}

} // namespace quda
26 changes: 0 additions & 26 deletions include/targets/hip/kernel_ops_target.h

This file was deleted.

0 comments on commit fd1393e

Please sign in to comment.