Skip to content

Commit

Permalink
Set default axes for N-dimensional fft APIs (#141)
Browse files Browse the repository at this point in the history
* improve index_sequence to return const std::array

* add default axes to ND fft functions

* update docstring for fftn

* Simplify index_sqeuence implementation

* Improve tests for index_sequence

* Recover fftn tests with given axes

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Sep 26, 2024
1 parent 2fc134d commit eaf9e85
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 45 deletions.
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ void fftshift(const ExecutionSpace& exec_space, ViewType& inout,
} else {
constexpr std::size_t rank = ViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> _axes = KokkosFFT::Impl::index_sequence<rank>(start);
auto _axes = KokkosFFT::Impl::index_sequence<int, rank, start>();
KokkosFFT::Impl::fftshift_impl(exec_space, inout, _axes);
}
}
Expand Down Expand Up @@ -281,7 +281,7 @@ void ifftshift(const ExecutionSpace& exec_space, ViewType& inout,
} else {
constexpr std::size_t rank = ViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> _axes = KokkosFFT::Impl::index_sequence<rank>(start);
auto _axes = KokkosFFT::Impl::index_sequence<int, rank, start>();
KokkosFFT::Impl::ifftshift_impl(exec_space, inout, _axes);
}
}
Expand Down
18 changes: 8 additions & 10 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,14 @@ std::size_t get_index(ContainerType& values, const ValueType& value) {
return it - values.begin();
}

template <typename T, std::size_t... I>
std::array<T, sizeof...(I)> make_sequence_array(std::index_sequence<I...>) {
return std::array<T, sizeof...(I)>{{I...}};
}

template <int N, typename T>
std::array<T, N> index_sequence(T const& start) {
auto sequence = make_sequence_array<T>(std::make_index_sequence<N>());
std::transform(sequence.begin(), sequence.end(), sequence.begin(),
[=](const T sequence) -> T { return start + sequence; });
template <typename IntType, std::size_t N, IntType start>
constexpr std::array<IntType, N> index_sequence() {
static_assert(std::is_integral_v<IntType> && std::is_signed_v<IntType>,
"index_sequence: IntType must be a signed integer type.");
std::array<IntType, N> sequence{};
for (std::size_t i = 0; i < N; ++i) {
sequence[i] = start + static_cast<IntType>(i);
}
return sequence;
}

Expand Down
32 changes: 31 additions & 1 deletion common/unit_test/Test_Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,4 +486,34 @@ TEST(ExtractExtents, 1Dto8D) {
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view6D), ref_extents6D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view7D), ref_extents7D);
EXPECT_EQ(KokkosFFT::Impl::extract_extents(view8D), ref_extents8D);
}
}

TEST(IndexSequence, 3Dto5D) {
using View3Dtype = Kokkos::View<double***, execution_space>;
using View4Dtype = Kokkos::View<double****, execution_space>;
using View5Dtype = Kokkos::View<double*****, execution_space>;

constexpr std::size_t DIM = 3;
std::size_t n1 = 1, n2 = 1, n3 = 2, n4 = 3, n5 = 5;
View3Dtype view3D("view3D", n1, n2, n3);
View4Dtype view4D("view4D", n1, n2, n3, n4);
View5Dtype view5D("view5D", n1, n2, n3, n4, n5);
constexpr int start0 = -static_cast<int>(View3Dtype::rank());
constexpr int start1 = -static_cast<int>(View4Dtype::rank());
constexpr int start2 = -static_cast<int>(View5Dtype::rank());

constexpr auto default_axes0 =
KokkosFFT::Impl::index_sequence<int, DIM, start0>();
constexpr auto default_axes1 =
KokkosFFT::Impl::index_sequence<int, DIM, start1>();
constexpr auto default_axes2 =
KokkosFFT::Impl::index_sequence<int, DIM, start2>();

std::array<int, DIM> ref_axes0 = {-3, -2, -1};
std::array<int, DIM> ref_axes1 = {-4, -3, -2};
std::array<int, DIM> ref_axes2 = {-5, -4, -3};

EXPECT_EQ(default_axes0, ref_axes0);
EXPECT_EQ(default_axes1, ref_axes1);
EXPECT_EQ(default_axes2, ref_axes2);
}
56 changes: 32 additions & 24 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,15 +470,17 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in,
/// \param exec_space [in] Kokkos execution space
/// \param in [in] Input data (complex)
/// \param out [out] Ouput data (complex)
/// \param axes [in] Axes over which FFT is performed
/// \param axes [in] Axes over which FFT is performed (optional)
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void fftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void fftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -503,15 +505,17 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
/// \param exec_space [in] Kokkos execution space
/// \param in [in] Input data (complex)
/// \param out [out] Ouput data (complex)
/// \param axes [in] Axes over which FFT is performed
/// \param axes [in] Axes over which FFT is performed (optional)
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void ifftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -538,15 +542,17 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
/// \param exec_space [in] Kokkos execution space
/// \param in [in] Input data (real)
/// \param out [out] Ouput data (complex)
/// \param axes [in] Axes over which FFT is performed
/// \param axes [in] Axes over which FFT is performed (optional)
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void rfftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down Expand Up @@ -579,15 +585,17 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
/// \param exec_space [in] Kokkos execution space
/// \param in [in] Input data (complex)
/// \param out [out] Ouput data (real)
/// \param axes [in] Axes over which FFT is performed
/// \param axes [in] Axes over which FFT is performed (optional)
/// \param norm [in] How the normalization is applied (optional)
/// \param s [in] Shape of the transformed axis of the output (optional)
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void irfftn(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, axis_type<DIM> axes,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {0}) {
std::size_t DIM = InViewType::rank()>
void irfftn(
const ExecutionSpace& exec_space, const InViewType& in, OutViewType& out,
axis_type<DIM> axes =
KokkosFFT::Impl::index_sequence<int, DIM, -static_cast<int>(DIM)>(),
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward,
shape_type<DIM> s = {}) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down
51 changes: 43 additions & 8 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2304,7 +2304,7 @@ void test_fftn_2dfft_2dview() {

ComplexView2DType x("x", n0, n1);
ComplexView2DType out("out", n0, n1), out1("out1", n0, n1),
out2("out2", n0, n1);
out2("out2", n0, n1), out_no_axes("out_no_axes", n0, n1);
ComplexView2DType out_b("out_b", n0, n1), out_o("out_o", n0, n1),
out_f("out_f", n0, n1);

Expand All @@ -2322,6 +2322,8 @@ void test_fftn_2dfft_2dview() {

using axes_type = KokkosFFT::axis_type<2>;
axes_type axes = {-2, -1};
KokkosFFT::fftn(execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x, out_b, axes,
Expand All @@ -2335,6 +2337,7 @@ void test_fftn_2dfft_2dview() {
multiply(out_f, static_cast<T>(n0 * n1));

EXPECT_TRUE(allclose(out, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out2, 1.e-5, 1.e-6));
Expand Down Expand Up @@ -2368,7 +2371,7 @@ void test_ifftn_2dfft_2dview() {

ComplexView2DType x("x", n0, n1);
ComplexView2DType out("out", n0, n1), out1("out1", n0, n1),
out2("out2", n0, n1);
out2("out2", n0, n1), out_no_axes("out_no_axes", n0, n1);
ComplexView2DType out_b("out_b", n0, n1), out_o("out_o", n0, n1),
out_f("out_f", n0, n1);

Expand All @@ -2387,6 +2390,8 @@ void test_ifftn_2dfft_2dview() {
KokkosFFT::ifft(execution_space(), out1, out2,
KokkosFFT::Normalization::backward, /*axis=*/0);

KokkosFFT::ifftn(execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x, out_b, axes,
Expand All @@ -2400,6 +2405,7 @@ void test_ifftn_2dfft_2dview() {
multiply(out_f, 1.0 / static_cast<T>(n0 * n1));

EXPECT_TRUE(allclose(out, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out2, 1.e-5, 1.e-6));
Expand Down Expand Up @@ -2434,7 +2440,7 @@ void test_rfftn_2dfft_2dview() {

RealView2DType x("x", n0, n1), x_ref("x_ref", n0, n1);
ComplexView2DType out("out", n0, n1 / 2 + 1), out1("out1", n0, n1 / 2 + 1),
out2("out2", n0, n1 / 2 + 1);
out2("out2", n0, n1 / 2 + 1), out_no_axes("out_no_axes", n0, n1 / 2 + 1);
ComplexView2DType out_b("out_b", n0, n1 / 2 + 1),
out_o("out_o", n0, n1 / 2 + 1), out_f("out_f", n0, n1 / 2 + 1);

Expand All @@ -2451,6 +2457,10 @@ void test_rfftn_2dfft_2dview() {
KokkosFFT::fft(execution_space(), out1, out2,
KokkosFFT::Normalization::backward, /*axis=*/0);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
Expand All @@ -2471,6 +2481,7 @@ void test_rfftn_2dfft_2dview() {
multiply(out_f, static_cast<T>(n0 * n1));

EXPECT_TRUE(allclose(out, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out2, 1.e-5, 1.e-6));
Expand Down Expand Up @@ -2514,7 +2525,7 @@ void test_irfftn_2dfft_2dview() {
ComplexView2DType out1("out1", n0, n1 / 2 + 1);
RealView2DType out2("out2", n0, n1), out("out", n0, n1);
RealView2DType out_b("out_b", n0, n1), out_o("out_o", n0, n1),
out_f("out_f", n0, n1);
out_f("out_f", n0, n1), out_no_axes("out_no_axes", n0, n1);

const Kokkos::complex<T> I(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Expand All @@ -2530,6 +2541,11 @@ void test_irfftn_2dfft_2dview() {
KokkosFFT::irfft(execution_space(), out1, out2,
KokkosFFT::Normalization::backward, /*axis=*/1);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(
execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
Expand All @@ -2550,6 +2566,7 @@ void test_irfftn_2dfft_2dview() {
multiply(out_f, 1.0 / static_cast<T>(n0 * n1));

EXPECT_TRUE(allclose(out, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out2, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out2, 1.e-5, 1.e-6));
Expand Down Expand Up @@ -2678,7 +2695,7 @@ void test_fftn_3dfft_3dview(T atol = 1.0e-6) {
ComplexView3DType out("out", n0, n1, n2), out1("out1", n0, n1, n2),
out2("out2", n0, n1, n2), out3("out3", n0, n1, n2);
ComplexView3DType out_b("out_b", n0, n1, n2), out_o("out_o", n0, n1, n2),
out_f("out_f", n0, n1, n2);
out_f("out_f", n0, n1, n2), out_no_axes("out_no_axes", n0, n1, n2);

const Kokkos::complex<T> I(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Expand All @@ -2698,6 +2715,8 @@ void test_fftn_3dfft_3dview(T atol = 1.0e-6) {
KokkosFFT::fft(execution_space(), out2, out3,
KokkosFFT::Normalization::backward, /*axis=*/0);

KokkosFFT::fftn(execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::fftn(execution_space(), x, out_b, axes,
Expand All @@ -2711,6 +2730,7 @@ void test_fftn_3dfft_3dview(T atol = 1.0e-6) {
multiply(out_f, static_cast<T>(n0 * n1 * n2));

EXPECT_TRUE(allclose(out, out3, 1.e-5, atol));
EXPECT_TRUE(allclose(out_no_axes, out3, 1.e-5, atol));
EXPECT_TRUE(allclose(out_b, out3, 1.e-5, atol));
EXPECT_TRUE(allclose(out_o, out3, 1.e-5, atol));
EXPECT_TRUE(allclose(out_f, out3, 1.e-5, atol));
Expand All @@ -2726,7 +2746,7 @@ void test_ifftn_3dfft_3dview() {
ComplexView3DType out("out", n0, n1, n2), out1("out1", n0, n1, n2),
out2("out2", n0, n1, n2), out3("out3", n0, n1, n2);
ComplexView3DType out_b("out_b", n0, n1, n2), out_o("out_o", n0, n1, n2),
out_f("out_f", n0, n1, n2);
out_f("out_f", n0, n1, n2), out_no_axes("out_no_axes", n0, n1, n2);

const Kokkos::complex<T> I(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Expand All @@ -2746,6 +2766,8 @@ void test_ifftn_3dfft_3dview() {
KokkosFFT::ifft(execution_space(), out2, out3,
KokkosFFT::Normalization::backward, /*axis=*/0);

KokkosFFT::ifftn(execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
KokkosFFT::ifftn(execution_space(), x, out_b, axes,
Expand All @@ -2759,6 +2781,7 @@ void test_ifftn_3dfft_3dview() {
multiply(out_f, 1.0 / static_cast<T>(n0 * n1 * n2));

EXPECT_TRUE(allclose(out, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out3, 1.e-5, 1.e-6));
Expand All @@ -2776,7 +2799,8 @@ void test_rfftn_3dfft_3dview() {
out1("out1", n0, n1, n2 / 2 + 1), out2("out2", n0, n1, n2 / 2 + 1),
out3("out3", n0, n1, n2 / 2 + 1);
ComplexView3DType out_b("out_b", n0, n1, n2 / 2 + 1),
out_o("out_o", n0, n1, n2 / 2 + 1), out_f("out_f", n0, n1, n2 / 2 + 1);
out_o("out_o", n0, n1, n2 / 2 + 1), out_f("out_f", n0, n1, n2 / 2 + 1),
out_no_axes("out_no_axes", n0, n1, n2 / 2 + 1);

Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Kokkos::fill_random(x, random_pool, 1);
Expand All @@ -2795,6 +2819,10 @@ void test_rfftn_3dfft_3dview() {
KokkosFFT::fft(execution_space(), out2, out3,
KokkosFFT::Normalization::backward, /*axis=*/0);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::rfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
Expand All @@ -2815,6 +2843,7 @@ void test_rfftn_3dfft_3dview() {
multiply(out_f, static_cast<T>(n0 * n1 * n2));

EXPECT_TRUE(allclose(out, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out3, 1.e-5, 1.e-6));
Expand All @@ -2833,7 +2862,7 @@ void test_irfftn_3dfft_3dview() {
out2("out2", n0, n1, n2 / 2 + 1);
RealView3DType out("out", n0, n1, n2), out3("out3", n0, n1, n2);
RealView3DType out_b("out_b", n0, n1, n2), out_o("out_o", n0, n1, n2),
out_f("out_f", n0, n1, n2);
out_f("out_f", n0, n1, n2), out_no_axes("out_no_axes", n0, n1, n2);

const Kokkos::complex<T> I(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Expand All @@ -2852,6 +2881,11 @@ void test_irfftn_3dfft_3dview() {
KokkosFFT::irfft(execution_space(), out2, out3,
KokkosFFT::Normalization::backward, /*axis=*/2);

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(
execution_space(), x,
out_no_axes); // default: KokkosFFT::Normalization::backward

Kokkos::deep_copy(x, x_ref);
KokkosFFT::irfftn(execution_space(), x, out,
axes); // default: KokkosFFT::Normalization::backward
Expand All @@ -2872,6 +2906,7 @@ void test_irfftn_3dfft_3dview() {
multiply(out_f, 1.0 / static_cast<T>(n0 * n1 * n2));

EXPECT_TRUE(allclose(out, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_no_axes, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_b, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_o, out3, 1.e-5, 1.e-6));
EXPECT_TRUE(allclose(out_f, out3, 1.e-5, 1.e-6));
Expand Down

0 comments on commit eaf9e85

Please sign in to comment.