Skip to content

Commit

Permalink
cleanup in-place tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Oct 17, 2024
1 parent 48947d1 commit de41209
Showing 1 changed file with 50 additions and 45 deletions.
95 changes: 50 additions & 45 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,12 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) {

for (int i = 1; i < maxlen; i++) {
ComplexView1DType a("a", i), a_ref("a_ref", i);
ComplexUView1DType out(a.data(), i), _a(a.data(), i);
ComplexUView1DType a_hat(a.data(), i), inv_a_hat(a.data(), i);

// Used for Real transforms
ComplexView1DType outr("outr", i / 2 + 1);
RealUView1DType ar(reinterpret_cast<T*>(outr.data()), i),
_ar(reinterpret_cast<T*>(outr.data()), i);
ComplexView1DType ar_hat("ar_hat", i / 2 + 1);
RealUView1DType ar(reinterpret_cast<T*>(ar_hat.data()), i),
inv_ar_hat(reinterpret_cast<T*>(ar_hat.data()), i);
RealView1DType ar_ref("ar_ref", i);

const Kokkos::complex<T> I(1.0, 1.0);
Expand All @@ -106,14 +106,14 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) {

Kokkos::fence();

KokkosFFT::fft(execution_space(), a, out);
KokkosFFT::ifft(execution_space(), out, _a);
KokkosFFT::fft(execution_space(), a, a_hat);
KokkosFFT::ifft(execution_space(), a_hat, inv_a_hat);

KokkosFFT::rfft(execution_space(), ar, outr);
KokkosFFT::irfft(execution_space(), outr, _ar);
KokkosFFT::rfft(execution_space(), ar, ar_hat);
KokkosFFT::irfft(execution_space(), ar_hat, inv_ar_hat);

EXPECT_TRUE(allclose(_a, a_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(_ar, ar_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol));
}
}

Expand Down Expand Up @@ -1683,78 +1683,83 @@ void test_fft2_2dfft_2dview_inplace([[maybe_unused]] T atol = 1.0e-12) {
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

ComplexView2DType x("x", n0, n1), x_ref("x_ref", n0, n1);
ComplexUView2DType xout(x.data(), n0, n1), _x(x.data(), n0, n1);
ComplexUView2DType x_hat(x.data(), n0, n1), inv_x_hat(x.data(), n0, n1);

// Used for real transforms
ComplexView2DType xr_out("xr_out", n0, n1 / 2 + 1),
xr_out_ref("xr_out_ref", n0, n1 / 2 + 1), xhat("xhat", n0, n1 / 2 + 1),
xhat_ref("xhat_ref", n0, n1 / 2 + 1);
RealView2DType xr_ref("xr_ref", n0, n1), _xr_unpadded("_xr_unpadded", n0, n1),
_xr_ref("_xr_ref", n0, n1);
RealUView2DType xr(reinterpret_cast<T*>(xr_out.data()), n0, n1),
xr_correct(reinterpret_cast<T*>(xr_out.data()), n0, (n1 / 2 + 1) * 2),
_xr(reinterpret_cast<T*>(xhat.data()), n0, n1),
_xr_correct(reinterpret_cast<T*>(xhat.data()), n0, (n1 / 2 + 1) * 2);

auto sub_xr =
Kokkos::subview(xr_correct, Kokkos::ALL(), Kokkos::pair<int, int>(0, n1));

const Kokkos::complex<T> I(1.0, 1.0);
ComplexView2DType xr_hat("xr_hat", n0, n1 / 2 + 1),
xr_hat_ref("xr_hat_ref", n0, n1 / 2 + 1);
RealView2DType xr_ref("xr_ref", n0, n1),
inv_xr_hat_unpadded("inv_xr_hat_unpadded", n0, n1),
inv_xr_hat_ref("inv_xr_hat_ref", n0, n1);

// Unmanged views for in-place transforms
RealUView2DType xr(reinterpret_cast<T*>(xr_hat.data()), n0, n1),
inv_xr_hat(reinterpret_cast<T*>(xr_hat.data()), n0, n1);
RealUView2DType xr_unpadded(reinterpret_cast<T*>(xr_hat.data()), n0,
(n1 / 2 + 1) * 2);

// Initialize xr_hat through xr_unpadded
auto sub_xr_unpadded = Kokkos::subview(xr_unpadded, Kokkos::ALL(),
Kokkos::pair<int, int>(0, n1));

const Kokkos::complex<T> z(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Kokkos::fill_random(xr_ref, random_pool, 1.0);
Kokkos::fill_random(x, random_pool, I);
Kokkos::fill_random(xhat, random_pool, I);
Kokkos::deep_copy(sub_xr, xr_ref);
Kokkos::fill_random(x, random_pool, z);
Kokkos::deep_copy(sub_xr_unpadded, xr_ref);
Kokkos::deep_copy(x_ref, x);
Kokkos::deep_copy(xhat_ref, xhat);

using axes_type = KokkosFFT::axis_type<2>;
axes_type axes = {-2, -1};

if constexpr (std::is_same_v<LayoutType, Kokkos::LayoutLeft>) {
// in-place transforms are not supported if transpose is needed
EXPECT_THROW(KokkosFFT::fft2(execution_space(), x, xout,
EXPECT_THROW(KokkosFFT::fft2(execution_space(), x, x_hat,
KokkosFFT::Normalization::backward, axes),
std::runtime_error);
EXPECT_THROW(KokkosFFT::ifft2(execution_space(), xout, _x,
EXPECT_THROW(KokkosFFT::ifft2(execution_space(), x_hat, inv_x_hat,
KokkosFFT::Normalization::backward, axes),
std::runtime_error);
EXPECT_THROW(KokkosFFT::rfft2(execution_space(), xr, xr_out,
EXPECT_THROW(KokkosFFT::rfft2(execution_space(), xr, xr_hat,
KokkosFFT::Normalization::backward, axes),
std::runtime_error);
EXPECT_THROW(KokkosFFT::irfft2(execution_space(), xhat, _xr,
EXPECT_THROW(KokkosFFT::irfft2(execution_space(), xr_hat, inv_xr_hat,
KokkosFFT::Normalization::backward, axes),
std::runtime_error);
} else {
// Identity tests for complex transforms
KokkosFFT::fft2(execution_space(), x, xout,
KokkosFFT::fft2(execution_space(), x, x_hat,
KokkosFFT::Normalization::backward, axes);
KokkosFFT::ifft2(execution_space(), xout, _x,
KokkosFFT::ifft2(execution_space(), x_hat, inv_x_hat,
KokkosFFT::Normalization::backward, axes);
EXPECT_TRUE(allclose(_x, x_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_x_hat, x_ref, 1.e-5, atol));

// In-place transforms
KokkosFFT::rfft2(execution_space(), xr, xr_out,
KokkosFFT::rfft2(execution_space(), xr, xr_hat,
KokkosFFT::Normalization::backward, axes);

// Out-of-place transforms (reference)
KokkosFFT::rfft2(execution_space(), xr_ref, xr_out_ref,
KokkosFFT::rfft2(execution_space(), xr_ref, xr_hat_ref,
KokkosFFT::Normalization::backward, axes);
EXPECT_TRUE(allclose(xr_out, xr_out_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(xr_hat, xr_hat_ref, 1.e-5, atol));

// In-place transforms
KokkosFFT::irfft2(execution_space(), xhat, _xr,
Kokkos::fill_random(xr_hat, random_pool, z);
Kokkos::deep_copy(xr_hat_ref, xr_hat);
KokkosFFT::irfft2(execution_space(), xr_hat, inv_xr_hat,
KokkosFFT::Normalization::backward, axes);

// Out-of-place transforms (reference)
KokkosFFT::irfft2(execution_space(), xhat_ref, _xr_ref,
KokkosFFT::irfft2(execution_space(), xr_hat_ref, inv_xr_hat_ref,
KokkosFFT::Normalization::backward, axes);

auto sub_xr_correct = Kokkos::subview(_xr_correct, Kokkos::ALL(),
Kokkos::pair<int, int>(0, n1));
Kokkos::deep_copy(_xr_unpadded, sub_xr_correct);
RealUView2DType inv_xr_hat_padded(reinterpret_cast<T*>(xr_hat.data()), n0,
(n1 / 2 + 1) * 2);
auto sub_inv_xr_hat_padded = Kokkos::subview(
inv_xr_hat_padded, Kokkos::ALL(), Kokkos::pair<int, int>(0, n1));
Kokkos::deep_copy(inv_xr_hat_unpadded, sub_inv_xr_hat_padded);

EXPECT_TRUE(allclose(_xr_unpadded, _xr_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_xr_hat_unpadded, inv_xr_hat_ref, 1.e-5, atol));
}
}

Expand Down

0 comments on commit de41209

Please sign in to comment.