Skip to content

Commit

Permalink
Fix bugs without heFFTe
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiLehe committed Sep 18, 2024
1 parent 569bd1a commit 66f2d6d
Showing 1 changed file with 69 additions and 65 deletions.
134 changes: 69 additions & 65 deletions Source/ablastr/fields/IntegratedGreenFunctionSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,81 +194,92 @@ computePhiIGF ( amrex::MultiFab const & rho,

// Prepare to perform global FFT
// Since there is 1 MPI rank per box, here each MPI rank obtains its local box and the associated boxid
// (when not using heFFTe, there is only one box and thus the local box is the same as the global box)
int local_boxid = amrex::ParallelDescriptor::MyProc(); // because of how we made the DistributionMapping
amrex::Box local_nodal_box = realspace_ba[local_boxid];
amrex::Box local_box(local_nodal_box.smallEnd(), local_nodal_box.bigEnd());
local_box.shift(-realspace_box.smallEnd()); // This simplifies the setup because the global lo is zero now
// Since the domain decomposition is in the z-direction, setting up c_local_box is simple.
amrex::Box c_local_box = local_box;
c_local_box.setBig(0, local_box.length(0)/2+1);

// Allocate array in spectral space
using SpectralField = amrex::BaseFab< amrex::GpuComplex< amrex::Real > > ;
SpectralField tmp_rho_fft(c_local_box, 1, amrex::The_Device_Arena());
SpectralField tmp_G_fft(c_local_box, 1, amrex::The_Device_Arena());
tmp_rho_fft.shift(realspace_box.smallEnd());
tmp_G_fft.shift(realspace_box.smallEnd());

// Create FFT plans
BL_PROFILE_VAR_START(timer_plans);
if (local_boxid < realspace_ba.size()) {
// When not using heFFTe, there is only one box (the global box)
// It is taken care of by MPI rank 0; other ranks have no work (hence the if condition)

amrex::Box local_nodal_box = realspace_ba[local_boxid];
amrex::Box local_box(local_nodal_box.smallEnd(), local_nodal_box.bigEnd());
local_box.shift(-realspace_box.smallEnd()); // This simplifies the setup because the global lo is zero now
// Since the domain decomposition is in the z-direction, setting up c_local_box is simple.
amrex::Box c_local_box = local_box;
c_local_box.setBig(0, local_box.length(0)/2+1);

// Allocate array in spectral space
using SpectralField = amrex::BaseFab< amrex::GpuComplex< amrex::Real > > ;
SpectralField tmp_rho_fft(c_local_box, 1, amrex::The_Device_Arena());
SpectralField tmp_G_fft(c_local_box, 1, amrex::The_Device_Arena());
tmp_rho_fft.shift(realspace_box.smallEnd());
tmp_G_fft.shift(realspace_box.smallEnd());

// Create FFT plans
BL_PROFILE_VAR_START(timer_plans);
#if !defined(ABLASTR_USE_HEFFTE)
const amrex::IntVect fft_size = realspace_ba[local_boxid].length();
ablastr::math::anyfft::FFTplan forward_plan_rho = ablastr::math::anyfft::CreatePlan(
fft_size, tmp_rho[local_boxid].dataPtr(),
reinterpret_cast<ablastr::math::anyfft::Complex*>(tmp_rho_fft.dataPtr()),
ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM);
ablastr::math::anyfft::FFTplan forward_plan_G = ablastr::math::anyfft::CreatePlan(
fft_size, tmp_G[local_boxid].dataPtr(),
reinterpret_cast<ablastr::math::anyfft::Complex*>(tmp_G_fft.dataPtr()),
ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM);
ablastr::math::anyfft::FFTplan backward_plan = ablastr::math::anyfft::CreatePlan(
fft_size, tmp_G[local_boxid].dataPtr(),
reinterpret_cast<ablastr::math::anyfft::Complex*>( tmp_G_fft.dataPtr()),
ablastr::math::anyfft::direction::C2R, AMREX_SPACEDIM);
const amrex::IntVect fft_size = realspace_ba[local_boxid].length();
ablastr::math::anyfft::FFTplan forward_plan_rho = ablastr::math::anyfft::CreatePlan(
fft_size, tmp_rho[local_boxid].dataPtr(),
reinterpret_cast<ablastr::math::anyfft::Complex*>(tmp_rho_fft.dataPtr()),
ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM);
ablastr::math::anyfft::FFTplan forward_plan_G = ablastr::math::anyfft::CreatePlan(
fft_size, tmp_G[local_boxid].dataPtr(),
reinterpret_cast<ablastr::math::anyfft::Complex*>(tmp_G_fft.dataPtr()),
ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM);
ablastr::math::anyfft::FFTplan backward_plan = ablastr::math::anyfft::CreatePlan(
fft_size, tmp_G[local_boxid].dataPtr(),
reinterpret_cast<ablastr::math::anyfft::Complex*>( tmp_G_fft.dataPtr()),
ablastr::math::anyfft::direction::C2R, AMREX_SPACEDIM);
#elif defined(ABLASTR_USE_HEFFTE)
#if defined(AMREX_USE_CUDA)
heffte::fft3d_r2c<heffte::backend::cufft> fft
heffte::fft3d_r2c<heffte::backend::cufft> fft
#elif defined(AMREX_USE_HIP)
heffte::fft3d_r2c<heffte::backend::rocfft> fft
heffte::fft3d_r2c<heffte::backend::rocfft> fft
#else
heffte::fft3d_r2c<heffte::backend::fftw> fft
heffte::fft3d_r2c<heffte::backend::fftw> fft
#endif
({{local_box.smallEnd(0), local_box.smallEnd(1), local_box.smallEnd(2)},
{local_box.bigEnd(0), local_box.bigEnd(1), local_box.bigEnd(2)}},
{{c_local_box.smallEnd(0), c_local_box.smallEnd(1), c_local_box.smallEnd(2)},
{c_local_box.bigEnd(0), c_local_box.bigEnd(1), c_local_box.bigEnd(2)}},
0, amrex::ParallelDescriptor::Communicator());
using heffte_complex = typename heffte::fft_output<amrex::Real>::type;
heffte_complex* rho_fft_data = (heffte_complex*) tmp_rho_fft.dataPtr();
heffte_complex* G_fft_data = (heffte_complex*) tmp_G_fft.dataPtr();
({{local_box.smallEnd(0), local_box.smallEnd(1), local_box.smallEnd(2)},
{local_box.bigEnd(0), local_box.bigEnd(1), local_box.bigEnd(2)}},
{{c_local_box.smallEnd(0), c_local_box.smallEnd(1), c_local_box.smallEnd(2)},
{c_local_box.bigEnd(0), c_local_box.bigEnd(1), c_local_box.bigEnd(2)}},
0, amrex::ParallelDescriptor::Communicator());
using heffte_complex = typename heffte::fft_output<amrex::Real>::type;
heffte_complex* rho_fft_data = (heffte_complex*) tmp_rho_fft.dataPtr();
heffte_complex* G_fft_data = (heffte_complex*) tmp_G_fft.dataPtr();
#endif
BL_PROFILE_VAR_STOP(timer_plans);
BL_PROFILE_VAR_STOP(timer_plans);

// Perform forward FFTs
BL_PROFILE_VAR_START(timer_ffts);
// Perform forward FFTs
BL_PROFILE_VAR_START(timer_ffts);
#if !defined(ABLASTR_USE_HEFFTE)
ablastr::math::anyfft::Execute(forward_plan_rho);
ablastr::math::anyfft::Execute(forward_plan_G);
ablastr::math::anyfft::Execute(forward_plan_rho);
ablastr::math::anyfft::Execute(forward_plan_G);
#elif defined(ABLASTR_USE_HEFFTE)
fft.forward(tmp_rho[local_boxid].dataPtr(), rho_fft_data);
fft.forward(tmp_G[local_boxid].dataPtr(), G_fft_data);
fft.forward(tmp_rho[local_boxid].dataPtr(), rho_fft_data);
fft.forward(tmp_G[local_boxid].dataPtr(), G_fft_data);
#endif
BL_PROFILE_VAR_STOP(timer_ffts);
BL_PROFILE_VAR_STOP(timer_ffts);

// Multiply tmp_G_fft and tmp_rho_fft in spectral space
// Store the result in-place in tmp_G_fft, to save memory
tmp_G_fft.template mult<amrex::RunOn::Device>(tmp_rho_fft, 0, 0, 1);
amrex::Gpu::streamSynchronize();
// Multiply tmp_G_fft and tmp_rho_fft in spectral space
// Store the result in-place in tmp_G_fft, to save memory
tmp_G_fft.template mult<amrex::RunOn::Device>(tmp_rho_fft, 0, 0, 1);
amrex::Gpu::streamSynchronize();

// Perform backward FFT
BL_PROFILE_VAR_START(timer_ffts);
// Perform backward FFT
BL_PROFILE_VAR_START(timer_ffts);
#if !defined(ABLASTR_USE_HEFFTE)
ablastr::math::anyfft::Execute(backward_plan);
ablastr::math::anyfft::Execute(backward_plan);
#elif defined(ABLASTR_USE_HEFFTE)
fft.backward(G_fft_data, tmp_G[local_boxid].dataPtr());
fft.backward(G_fft_data, tmp_G[local_boxid].dataPtr());
#endif
BL_PROFILE_VAR_STOP(timer_ffts);
BL_PROFILE_VAR_STOP(timer_ffts);

#if !defined(ABLASTR_USE_HEFFTE)
// Destroy FFT plans
ablastr::math::anyfft::DestroyPlan(forward_plan_G);
ablastr::math::anyfft::DestroyPlan(forward_plan_rho);
ablastr::math::anyfft::DestroyPlan(backward_plan);
#endif
}

// Normalize, since (FFT + inverse FFT) results in a factor N
const amrex::Real normalization = 1._rt / realspace_box.numPts();
Expand All @@ -279,13 +290,6 @@ computePhiIGF ( amrex::MultiFab const & rho,
phi.ParallelCopy( tmp_G, 0, 0, 1, amrex::IntVect::TheZeroVector(), phi.nGrowVect());
BL_PROFILE_VAR_STOP(timer_pcopies);

#if !defined(ABLASTR_USE_HEFFTE)
// Destroy FFT plans
ablastr::math::anyfft::DestroyPlan(forward_plan_G);
ablastr::math::anyfft::DestroyPlan(forward_plan_rho);
ablastr::math::anyfft::DestroyPlan(backward_plan);
#endif

#endif // ABLASTR_USE_FFT
}
} // namespace ablastr::fields

0 comments on commit 66f2d6d

Please sign in to comment.