diff --git a/Source/ablastr/fields/IntegratedGreenFunctionSolver.cpp b/Source/ablastr/fields/IntegratedGreenFunctionSolver.cpp index 850919f8451..5cf01f0b053 100644 --- a/Source/ablastr/fields/IntegratedGreenFunctionSolver.cpp +++ b/Source/ablastr/fields/IntegratedGreenFunctionSolver.cpp @@ -194,81 +194,92 @@ computePhiIGF ( amrex::MultiFab const & rho, // Prepare to perform global FFT // Since there is 1 MPI rank per box, here each MPI rank obtains its local box and the associated boxid - // (when not using heFFTe, there is only one box and thus the local box is the same as the global box) int local_boxid = amrex::ParallelDescriptor::MyProc(); // because of how we made the DistributionMapping - amrex::Box local_nodal_box = realspace_ba[local_boxid]; - amrex::Box local_box(local_nodal_box.smallEnd(), local_nodal_box.bigEnd()); - local_box.shift(-realspace_box.smallEnd()); // This simplifies the setup because the global lo is zero now - // Since we the domain decompostion is in the z-direction, setting up c_local_box is simple. 
- amrex::Box c_local_box = local_box; - c_local_box.setBig(0, local_box.length(0)/2+1); - - // Allocate array in spectral space - using SpectralField = amrex::BaseFab< amrex::GpuComplex< amrex::Real > > ; - SpectralField tmp_rho_fft(c_local_box, 1, amrex::The_Device_Arena()); - SpectralField tmp_G_fft(c_local_box, 1, amrex::The_Device_Arena()); - tmp_rho_fft.shift(realspace_box.smallEnd()); - tmp_G_fft.shift(realspace_box.smallEnd()); - - // Create FFT plans - BL_PROFILE_VAR_START(timer_plans); + if (local_boxid < realspace_ba.size()) { + // When not using heFFTe, there is only one box (the global box) + // It is taken care of by MPI rank 0; other ranks have no work (hence the if condition) + + amrex::Box local_nodal_box = realspace_ba[local_boxid]; + amrex::Box local_box(local_nodal_box.smallEnd(), local_nodal_box.bigEnd()); + local_box.shift(-realspace_box.smallEnd()); // This simplifies the setup because the global lo is zero now + // Since the domain decomposition is in the z-direction, setting up c_local_box is simple. 
+ amrex::Box c_local_box = local_box; + c_local_box.setBig(0, local_box.length(0)/2+1); + + // Allocate array in spectral space + using SpectralField = amrex::BaseFab< amrex::GpuComplex< amrex::Real > > ; + SpectralField tmp_rho_fft(c_local_box, 1, amrex::The_Device_Arena()); + SpectralField tmp_G_fft(c_local_box, 1, amrex::The_Device_Arena()); + tmp_rho_fft.shift(realspace_box.smallEnd()); + tmp_G_fft.shift(realspace_box.smallEnd()); + + // Create FFT plans + BL_PROFILE_VAR_START(timer_plans); #if !defined(ABLASTR_USE_HEFFTE) - const amrex::IntVect fft_size = realspace_ba[local_boxid].length(); - ablastr::math::anyfft::FFTplan forward_plan_rho = ablastr::math::anyfft::CreatePlan( - fft_size, tmp_rho[local_boxid].dataPtr(), - reinterpret_cast(tmp_rho_fft.dataPtr()), - ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM); - ablastr::math::anyfft::FFTplan forward_plan_G = ablastr::math::anyfft::CreatePlan( - fft_size, tmp_G[local_boxid].dataPtr(), - reinterpret_cast(tmp_G_fft.dataPtr()), - ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM); - ablastr::math::anyfft::FFTplan backward_plan = ablastr::math::anyfft::CreatePlan( - fft_size, tmp_G[local_boxid].dataPtr(), - reinterpret_cast( tmp_G_fft.dataPtr()), - ablastr::math::anyfft::direction::C2R, AMREX_SPACEDIM); + const amrex::IntVect fft_size = realspace_ba[local_boxid].length(); + ablastr::math::anyfft::FFTplan forward_plan_rho = ablastr::math::anyfft::CreatePlan( + fft_size, tmp_rho[local_boxid].dataPtr(), + reinterpret_cast(tmp_rho_fft.dataPtr()), + ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM); + ablastr::math::anyfft::FFTplan forward_plan_G = ablastr::math::anyfft::CreatePlan( + fft_size, tmp_G[local_boxid].dataPtr(), + reinterpret_cast(tmp_G_fft.dataPtr()), + ablastr::math::anyfft::direction::R2C, AMREX_SPACEDIM); + ablastr::math::anyfft::FFTplan backward_plan = ablastr::math::anyfft::CreatePlan( + fft_size, tmp_G[local_boxid].dataPtr(), + reinterpret_cast( tmp_G_fft.dataPtr()), + 
ablastr::math::anyfft::direction::C2R, AMREX_SPACEDIM); #elif defined(ABLASTR_USE_HEFFTE) #if defined(AMREX_USE_CUDA) - heffte::fft3d_r2c fft + heffte::fft3d_r2c fft #elif defined(AMREX_USE_HIP) - heffte::fft3d_r2c fft + heffte::fft3d_r2c fft #else - heffte::fft3d_r2c fft + heffte::fft3d_r2c fft #endif - ({{local_box.smallEnd(0), local_box.smallEnd(1), local_box.smallEnd(2)}, - {local_box.bigEnd(0), local_box.bigEnd(1), local_box.bigEnd(2)}}, - {{c_local_box.smallEnd(0), c_local_box.smallEnd(1), c_local_box.smallEnd(2)}, - {c_local_box.bigEnd(0), c_local_box.bigEnd(1), c_local_box.bigEnd(2)}}, - 0, amrex::ParallelDescriptor::Communicator()); - using heffte_complex = typename heffte::fft_output::type; - heffte_complex* rho_fft_data = (heffte_complex*) tmp_rho_fft.dataPtr(); - heffte_complex* G_fft_data = (heffte_complex*) tmp_G_fft.dataPtr(); + ({{local_box.smallEnd(0), local_box.smallEnd(1), local_box.smallEnd(2)}, + {local_box.bigEnd(0), local_box.bigEnd(1), local_box.bigEnd(2)}}, + {{c_local_box.smallEnd(0), c_local_box.smallEnd(1), c_local_box.smallEnd(2)}, + {c_local_box.bigEnd(0), c_local_box.bigEnd(1), c_local_box.bigEnd(2)}}, + 0, amrex::ParallelDescriptor::Communicator()); + using heffte_complex = typename heffte::fft_output::type; + heffte_complex* rho_fft_data = (heffte_complex*) tmp_rho_fft.dataPtr(); + heffte_complex* G_fft_data = (heffte_complex*) tmp_G_fft.dataPtr(); #endif - BL_PROFILE_VAR_STOP(timer_plans); + BL_PROFILE_VAR_STOP(timer_plans); - // Perform forward FFTs - BL_PROFILE_VAR_START(timer_ffts); + // Perform forward FFTs + BL_PROFILE_VAR_START(timer_ffts); #if !defined(ABLASTR_USE_HEFFTE) - ablastr::math::anyfft::Execute(forward_plan_rho); - ablastr::math::anyfft::Execute(forward_plan_G); + ablastr::math::anyfft::Execute(forward_plan_rho); + ablastr::math::anyfft::Execute(forward_plan_G); #elif defined(ABLASTR_USE_HEFFTE) - fft.forward(tmp_rho[local_boxid].dataPtr(), rho_fft_data); - fft.forward(tmp_G[local_boxid].dataPtr(), G_fft_data); + 
+ fft.forward(tmp_rho[local_boxid].dataPtr(), rho_fft_data); + fft.forward(tmp_G[local_boxid].dataPtr(), G_fft_data); #endif - BL_PROFILE_VAR_STOP(timer_ffts); + BL_PROFILE_VAR_STOP(timer_ffts); - // Multiply tmp_G_fft and tmp_rho_fft in spectral space - // Store the result in-place in Gtmp_G_fft, to save memory - tmp_G_fft.template mult(tmp_rho_fft, 0, 0, 1); - amrex::Gpu::streamSynchronize(); + // Multiply tmp_G_fft and tmp_rho_fft in spectral space + // Store the result in-place in tmp_G_fft, to save memory + tmp_G_fft.template mult(tmp_rho_fft, 0, 0, 1); + amrex::Gpu::streamSynchronize(); - // Perform backward FFT - BL_PROFILE_VAR_START(timer_ffts); + // Perform backward FFT + BL_PROFILE_VAR_START(timer_ffts); #if !defined(ABLASTR_USE_HEFFTE) - ablastr::math::anyfft::Execute(backward_plan); + ablastr::math::anyfft::Execute(backward_plan); #elif defined(ABLASTR_USE_HEFFTE) - fft.backward(G_fft_data, tmp_G[local_boxid].dataPtr()); + fft.backward(G_fft_data, tmp_G[local_boxid].dataPtr()); #endif - BL_PROFILE_VAR_STOP(timer_ffts); + BL_PROFILE_VAR_STOP(timer_ffts); + +#if !defined(ABLASTR_USE_HEFFTE) + // Destroy FFT plans + ablastr::math::anyfft::DestroyPlan(forward_plan_G); + ablastr::math::anyfft::DestroyPlan(forward_plan_rho); + ablastr::math::anyfft::DestroyPlan(backward_plan); +#endif + } // Normalize, since (FFT + inverse FFT) results in a factor N const amrex::Real normalization = 1._rt / realspace_box.numPts(); @@ -279,13 +290,6 @@ computePhiIGF ( amrex::MultiFab const & rho, phi.ParallelCopy( tmp_G, 0, 0, 1, amrex::IntVect::TheZeroVector(), phi.nGrowVect()); BL_PROFILE_VAR_STOP(timer_pcopies); -#if !defined(ABLASTR_USE_HEFFTE) - // Loop to destroy FFT plans - ablastr::math::anyfft::DestroyPlan(forward_plan_G); - ablastr::math::anyfft::DestroyPlan(forward_plan_rho); - ablastr::math::anyfft::DestroyPlan(backward_plan); -#endif - #endif // ABLASTR_USE_FFT } } // namespace ablastr::fields