Skip to content

Commit

Permalink
Apply traits to operations (#120)
Browse files Browse the repository at this point in the history
* Strict type checks in uunary operations

* Strict type checks in binary operations

* Strict tpe checks in Plan creation

* fix: missing execution space type

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jul 23, 2024
1 parent bab1620 commit 73b2c77
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 462 deletions.
41 changes: 21 additions & 20 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <Kokkos_Core.hpp>
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
Expand Down Expand Up @@ -131,16 +132,6 @@ void roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(Kokkos::is_view<ViewType>::value,
"fftshift_impl: ViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"fftshift_impl: ViewType must be either LayoutLeft or LayoutRight.");
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename ViewType::memory_space>::accessible,
"fftshift_impl: execution_space cannot access data in ViewType");

static_assert(ViewType::rank() >= DIM,
"fftshift_impl: Rank of View must be larger thane "
"or equal to the Rank of shift axes.");
Expand All @@ -151,16 +142,6 @@ void fftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift_impl(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(Kokkos::is_view<ViewType>::value,
"ifftshift_impl: ViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<ViewType>,
"ifftshift_impl: ViewType must be either LayoutLeft or LayoutRight.");
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename ViewType::memory_space>::accessible,
"ifftshift_impl: execution_space cannot access data in ViewType");

static_assert(ViewType::rank() >= DIM,
"ifftshift_impl: Rank of View must be larger "
"thane or equal to the Rank of shift axes.");
Expand Down Expand Up @@ -243,6 +224,11 @@ auto rfftfreq(const ExecutionSpace&, const std::size_t n,
template <typename ExecutionSpace, typename ViewType>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
Expand All @@ -262,6 +248,11 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"fftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
KokkosFFT::Impl::fftshift_impl(exec_space, inout, axes);
}

Expand All @@ -273,6 +264,11 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
std::optional<int> axes = std::nullopt) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
if (axes) {
axis_type<1> _axes{axes.value()};
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
Expand All @@ -292,6 +288,11 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
template <typename ExecutionSpace, typename ViewType, std::size_t DIM = 1>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
axis_type<DIM> axes) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"ifftshift: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, axes);
}
} // namespace KokkosFFT
Expand Down
77 changes: 15 additions & 62 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <Kokkos_Core.hpp>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"
#include "KokkosFFT_utils.hpp"
Expand Down Expand Up @@ -158,33 +159,14 @@ class Plan {
OutViewType& out, KokkosFFT::Direction direction, int axis,
std::optional<std::size_t> n = std::nullopt)
: m_exec_space(exec_space), m_axes({axis}), m_direction(direction) {
static_assert(Kokkos::is_view<InViewType>::value,
"Plan::Plan: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"Plan::Plan: OutViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
"Plan::Plan: InViewType must be either LayoutLeft or LayoutRight.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
"Plan::Plan: OutViewType must be either LayoutLeft or LayoutRight.");

static_assert(InViewType::rank() == OutViewType::rank(),
"Plan::Plan: InViewType and OutViewType must have "
"the same rank.");
static_assert(std::is_same_v<typename InViewType::array_layout,
typename OutViewType::array_layout>,
"Plan::Plan: InViewType and OutViewType must have "
"the same Layout.");

static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename InViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"Plan::Plan: InViewType and OutViewType must have the same base "
"floating point type (float/double), the same layout "
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");

if (KokkosFFT::Impl::is_real_v<in_value_type> &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down Expand Up @@ -230,34 +212,14 @@ class Plan {
OutViewType& out, KokkosFFT::Direction direction,
axis_type<DIM> axes, shape_type<DIM> s = {0})
: m_exec_space(exec_space), m_axes(axes), m_direction(direction) {
static_assert(Kokkos::is_view<InViewType>::value,
"Plan::Plan: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"Plan::Plan: OutViewType is not a Kokkos::View.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
"Plan::Plan: InViewType must be either LayoutLeft or LayoutRight.");
static_assert(
KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
"Plan::Plan: OutViewType must be either LayoutLeft or LayoutRight.");

static_assert(InViewType::rank() == OutViewType::rank(),
"Plan::Plan: InViewType and OutViewType must have "
"the same rank.");

static_assert(std::is_same_v<typename InViewType::array_layout,
typename OutViewType::array_layout>,
"Plan::Plan: InViewType and OutViewType must have "
"the same Layout.");

static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename InViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"Plan::Plan: InViewType and OutViewType must have the same base "
"floating point type (float/double), the same layout "
"(LayoutLeft/LayoutRight), "
"and the same rank. ExecutionSpace must be accessible to the data in "
"InViewType and OutViewType.");

if (std::is_floating_point<in_value_type>::value &&
m_direction != KokkosFFT::Direction::forward) {
Expand Down Expand Up @@ -302,15 +264,6 @@ class Plan {
/// \param out [in] Ouput data
template <typename InViewType2, typename OutViewType2>
void good(const InViewType2& in, const OutViewType2& out) const {
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename InViewType2::memory_space>::accessible,
"Plan::good: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType2::memory_space>::accessible,
"Plan::good: execution_space cannot access data in OutViewType");

using nonConstInViewType2 = std::remove_cv_t<InViewType2>;
using nonConstOutViewType2 = std::remove_cv_t<OutViewType2>;
static_assert(std::is_same_v<nonConstInViewType2, nonConstInViewType>,
Expand Down
Loading

0 comments on commit 73b2c77

Please sign in to comment.