Skip to content

Commit

Permalink
Add oneMKL FFT to AnyFFT wrapper for SYCL backend
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Aug 14, 2024
1 parent 37046f3 commit bd43dd3
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Source/Make.WarpX
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ ifeq ($(USE_FFT),TRUE)
INCLUDE_LOCATIONS += $(ROC_PATH)/rocfft/include
LIBRARY_LOCATIONS += $(ROC_PATH)/rocfft/lib
libraries += -lrocfft
else ifeq ($(USE_SYCL),TRUE)
# nothing
else # Running on CPU
# Use FFTW
ifeq ($(PRECISION),FLOAT)
Expand Down
17 changes: 17 additions & 0 deletions Source/ablastr/math/fft/AnyFFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#ifdef ABLASTR_USE_FFT
# include <AMReX_Config.H>
# include <AMReX_GpuComplex.H>
# include <AMReX_LayoutData.H>

# if defined(AMREX_USE_CUDA)
Expand All @@ -20,6 +21,8 @@
# else
# include <rocfft.h>
# endif
# elif defined(AMREX_USE_SYCL)
# include <oneapi/mkl/dfti.hpp>
# else
# include <fftw3.h>
# endif
Expand Down Expand Up @@ -62,6 +65,8 @@ namespace ablastr::math::anyfft
# else
using Complex = double2;
# endif
# elif defined(AMREX_USE_SYCL)
using Complex = amrex::GpuComplex<amrex::Real>;
# else
# ifdef AMREX_USE_FLOAT
using Complex = fftwf_complex;
Expand All @@ -77,6 +82,15 @@ namespace ablastr::math::anyfft
using VendorFFTPlan = cufftHandle;
# elif defined(AMREX_USE_HIP)
using VendorFFTPlan = rocfft_plan;
# elif defined(AMREX_USE_SYCL)
using VendorFFTPlan = oneapi::mkl::dft::descriptor<
# ifdef AMREX_USE_FLOAT
oneapi::mkl::dft::precision::SINGLE,
# else
oneapi::mkl::dft::precision::DOUBLE,
# endif
oneapi::mkl::dft::domain::REAL> *;
// dft::descriptor has no default ctor, so we have to use ptr.
# else
# ifdef AMREX_USE_FLOAT
using VendorFFTPlan = fftwf_plan;
Expand All @@ -99,6 +113,9 @@ namespace ablastr::math::anyfft
VendorFFTPlan m_plan; /**< Vendor FFT plan */
direction m_dir; /**< direction (C2R or R2C) */
int m_dim; /**< Dimensionality of the FFT plan */
#ifdef AMREX_USE_SYCL
amrex::gpuStream_t m_stream;
#endif
};

/** Collection of FFT plans, one FFTplan per box */
Expand Down
2 changes: 2 additions & 0 deletions Source/ablastr/math/fft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ foreach(D IN LISTS WarpX_DIMS)
target_sources(ablastr_${SD} PRIVATE WrapCuFFT.cpp)
elseif(WarpX_COMPUTE STREQUAL HIP)
target_sources(ablastr_${SD} PRIVATE WrapRocFFT.cpp)
elseif(WarpX_COMPUTE STREQUAL SYCL)
target_sources(ablastr_${SD} PRIVATE WrapMklFFT.cpp)
else()
target_sources(ablastr_${SD} PRIVATE WrapFFTW.cpp)
endif()
Expand Down
2 changes: 2 additions & 0 deletions Source/ablastr/math/fft/Make.package
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ ifeq ($(USE_FFT),TRUE)
CEXE_sources += WrapCuFFT.cpp
else ifeq ($(USE_HIP),TRUE)
CEXE_sources += WrapRocFFT.cpp
else ifeq ($(USE_SYCL),TRUE)
CEXE_sources += WrapMklFFT.cpp
else
CEXE_sources += WrapFFTW.cpp
endif
Expand Down
96 changes: 96 additions & 0 deletions Source/ablastr/math/fft/WrapMklFFT.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/* Copyright 2019-2023
*
* This file is part of ABLASTR.
*
* License: BSD-3-Clause-LBNL
*/

#include "AnyFFT.H"

#include "ablastr/utils/TextMsg.H"
#include "ablastr/profiler/ProfilerWrapper.H"

#include <cstdint>

namespace ablastr::math::anyfft
{

/** Global FFT setup hook. This backend needs no global initialization, so this is a no-op. */
void setup ()
{
    // nothing to do
}

/** Global FFT teardown hook. This backend holds no global state, so this is a no-op. */
void cleanup ()
{
    // nothing to do
}

/** \brief Create an FFT plan backed by a heap-allocated oneMKL DFT descriptor.
 *
 * The descriptor is configured for out-of-place transforms, given explicit
 * forward strides derived from \p real_size, and committed on
 * amrex::Gpu::Device::streamQueue(). The GPU stream that is current at
 * creation time is recorded in the plan so that Execute can detect a change
 * of stream.
 *
 * \param[in] real_size     number of real-space points along each direction;
 *                          only the first \p dim components are read
 * \param[in] real_array    pointer to the real-space data, stored in the plan
 * \param[in] complex_array pointer to the spectral-space data, stored in the plan
 * \param[in] dir           transform direction (R2C or C2R), stored in the plan
 * \param[in] dim           dimensionality of the transform; must be 1, 2 or 3,
 *                          any other value aborts
 * \return the populated plan; its descriptor must be released with DestroyPlan
 */
FFTplan CreatePlan(const amrex::IntVect& real_size, amrex::Real * const real_array,
                   Complex * const complex_array, const direction dir, const int dim)
{
    FFTplan fft_plan;
    ABLASTR_PROFILE("ablastr::math::anyfft::CreatePlan");

    // Initialize fft_plan.m_plan with the vendor fft plan.
    // The strides are filled per branch (rather than sized from dim up front) so
    // that an unsupported dim reaches the abort below instead of failing while
    // sizing the vector. Element 0 is the offset entry (see the oneMKL stride
    // documentation); the remaining entries are the per-dimension strides, with
    // direction 0 of real_size being the contiguous one.
    // The components of real_size are only read inside the branch that needs
    // them, because real_size may hold fewer than three components.
    std::vector<std::int64_t> strides;
    if (dim == 3) {
        const auto nx = static_cast<std::int64_t>(real_size[0]);
        const auto ny = static_cast<std::int64_t>(real_size[1]);
        const auto nz = static_cast<std::int64_t>(real_size[2]);
        fft_plan.m_plan = new std::remove_pointer_t<VendorFFTPlan>({nz, ny, nx});
        // The product is evaluated in 64-bit arithmetic to avoid int overflow.
        strides = {0, nx * ny, nx, 1};
    } else if (dim == 2) {
        const auto nx = static_cast<std::int64_t>(real_size[0]);
        const auto ny = static_cast<std::int64_t>(real_size[1]);
        fft_plan.m_plan = new std::remove_pointer_t<VendorFFTPlan>({ny, nx});
        strides = {0, nx, 1};
    } else if (dim == 1) {
        const auto nx = static_cast<std::int64_t>(real_size[0]);
        fft_plan.m_plan = new std::remove_pointer_t<VendorFFTPlan>(nx);
        strides = {0, 1};
    } else {
        ABLASTR_ABORT_WITH_MESSAGE("only dim=1, dim=2 and dim=3 have been implemented");
    }

    fft_plan.m_plan->set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                               DFTI_NOT_INPLACE);
    fft_plan.m_plan->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES,
                               strides.data());
    fft_plan.m_plan->commit(amrex::Gpu::Device::streamQueue());

    // Store meta-data in fft_plan
    fft_plan.m_real_array = real_array;
    fft_plan.m_complex_array = complex_array;
    fft_plan.m_dir = dir;
    fft_plan.m_dim = dim;
    // Remember the stream in use at creation time; Execute compares against it.
    fft_plan.m_stream = amrex::Gpu::gpuStream();

    return fft_plan;
}

/** \brief Release the oneMKL descriptor owned by an FFT plan.
 *
 * \param[in,out] fft_plan plan whose descriptor (allocated with new in
 *                         CreatePlan) is deleted; its handle is reset to
 *                         nullptr afterwards
 */
void DestroyPlan(FFTplan& fft_plan)
{
    delete fft_plan.m_plan;
    // Reset the handle so that destroying the same plan twice is harmless
    // (deleting a null pointer is a no-op) instead of a double free.
    fft_plan.m_plan = nullptr;
}

/** \brief Run the transform described by an FFT plan and block until it has completed.
 *
 * R2C plans run a forward transform from the real array to the complex array;
 * any other direction runs a backward transform from the complex array to the
 * real array.
 *
 * \param[in,out] fft_plan plan holding the committed descriptor, the data
 *                         pointers and the transform direction
 */
void Execute (FFTplan& fft_plan)
{
    // Synchronize first if the current stream is not the one recorded in the plan.
    const bool on_plan_stream = (fft_plan.m_stream == amrex::Gpu::gpuStream());
    if (!on_plan_stream) {
        amrex::Gpu::streamSynchronize();
    }

    // oneMKL is handed the spectral data as std::complex.
    auto* const spectral_data =
        reinterpret_cast<std::complex<amrex::Real>*>(fft_plan.m_complex_array);

    const bool is_forward = (fft_plan.m_dir == direction::R2C);
    sycl::event fft_done = is_forward
        ? oneapi::mkl::dft::compute_forward(
              *fft_plan.m_plan, fft_plan.m_real_array, spectral_data)
        : oneapi::mkl::dft::compute_backward(
              *fft_plan.m_plan, spectral_data, fft_plan.m_real_array);

    // Block until the transform has finished.
    fft_done.wait();
}
}

0 comments on commit bd43dd3

Please sign in to comment.