Skip to content

Commit

Permalink
FFT: OpenBC Solver
Browse files Browse the repository at this point in the history
This implements the algorithm of Hockney, Methods Comp. Phys. 9 (1970),
136-210 for solving Poisson's equation with open boundaries.
  • Loading branch information
WeiqunZhang committed Nov 7, 2024
1 parent f46a4e5 commit d4c01c3
Show file tree
Hide file tree
Showing 12 changed files with 437 additions and 22 deletions.
3 changes: 2 additions & 1 deletion Src/Base/AMReX_BoxArray.H
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ namespace amrex
*/
[[nodiscard]] BoxArray decompose (Box const& domain, int nboxes,
Array<bool,AMREX_SPACEDIM> const& decomp
= {AMREX_D_DECL(true,true,true)});
= {AMREX_D_DECL(true,true,true)},
bool no_overlap = false);

struct BARef
{
Expand Down
21 changes: 18 additions & 3 deletions Src/Base/AMReX_BoxArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,7 @@ bool match (const BoxArray& x, const BoxArray& y)
}

BoxArray decompose (Box const& domain, int nboxes,
Array<bool,AMREX_SPACEDIM> const& decomp)
Array<bool,AMREX_SPACEDIM> const& decomp, bool no_overlap)
{
auto ndecomp = std::count(decomp.begin(), decomp.end(), true);

Expand Down Expand Up @@ -2048,9 +2048,24 @@ BoxArray decompose (Box const& domain, int nboxes,
ilo += domlo[0];
ihi += domlo[0];
Box b{IntVect(AMREX_D_DECL(ilo,jlo,klo)),
IntVect(AMREX_D_DECL(ihi,jhi,khi))};
IntVect(AMREX_D_DECL(ihi,jhi,khi)), ixtyp};
if (b.ok()) {
bl.push_back(b.convert(ixtyp));
if (no_overlap) {
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
if (ixtyp.nodeCentered() &&
b.bigEnd(idim) == ccdomain.bigEnd(idim))
{
b.growHi(idim, 1);
}
}
} else {
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
if (ixtyp.nodeCentered()) {
b.growHi(idim, 1);
}
}
}
bl.push_back(b);
}
AMREX_D_TERM(},},})

Expand Down
1 change: 1 addition & 0 deletions Src/FFT/AMReX_FFT.H
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define AMREX_FFT_H_
#include <AMReX_Config.H>

#include <AMReX_FFT_OpenBCSolver.H>
#include <AMReX_FFT_R2C.H>
#include <AMReX_FFT_R2X.H>

Expand Down
142 changes: 142 additions & 0 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#ifndef AMREX_FFT_OPENBC_SOlVER_H_
#define AMREX_FFT_OPENBC_SOlVER_H_

#include <AMReX_FFT_R2C.H>

#include <AMReX_VisMF.H>

namespace amrex::FFT
{

template <typename T = Real>
class OpenBCSolver
{
public:
using MF = typename R2C<T>::MF;
using cMF = typename R2C<T>::cMF;

explicit OpenBCSolver (Box const& domain);

template <class F>
void setGreensFunction (F const& greens_function);

void solve (MF& phi, MF const& rho);

[[nodiscard]] Box const& Domain () const { return m_domain; }

private:
Box m_domain;
R2C<T> m_r2c;
cMF m_G_fft;
};

template <typename T>
OpenBCSolver<T>::OpenBCSolver (Box const& domain)
: m_domain(domain),
m_r2c(Box(domain.smallEnd(), domain.bigEnd()+domain.length(), domain.ixType()))
{
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
m_G_fft.define(sd->boxArray(), sd->DistributionMap(), 1, 0);
}

template <typename T>
template <class F>
void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
{
auto* infab = detail::get_fab(m_r2c.m_rx);
auto const& lo = m_domain.smallEnd();
auto const& lo3 = lo.dim3();
auto const& len = m_domain.length3d();
if (infab) {
auto const& a = infab->array();
auto box = infab->box();
GpuArray<int,3> nimages{1,1,1};
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
if (box.smallEnd(idim) == lo[idim] && box.length(idim) == 2*len[idim]) {
box.growHi(idim, -len[idim]+1); // +1 to include the middle plane
nimages[idim] = 2;
}
}
AMREX_ASSERT(nimages[0] == 2);
box.shift(-lo);
amrex::ParallelFor(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
{
if (i == len[0] || j == len[1] || k == len[2]) {
a(i+lo3.x,j+lo3.y,k+lo3.z) = T(0);
} else {
auto ii = i;
auto jj = (j > len[1]) ? 2*len[1]-j : j;
auto kk = (k > len[2]) ? 2*len[2]-k : k;
auto G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
for (int koff = 0; koff < nimages[2]; ++koff) {
int k2 = (koff == 0) ? k : 2*len[2]-k;
if (k2 == 2*len[2]) { continue; }
for (int joff = 0; joff < nimages[1]; ++joff) {
int j2 = (joff == 0) ? j : 2*len[1]-j;
if (j2 == 2*len[1]) { continue; }
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
int i2 = (ioff == 0) ? i : 2*len[0]-i;
if (i2 == 2*len[0]) { continue; }
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
}
}
}
}
});
}

m_r2c.forward(m_r2c.m_rx);

auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto const* srcfab = detail::get_fab(*sd);
if (srcfab) {
auto* dstfab = detail::get_fab(m_G_fft);
if (dstfab) {
#if defined(AMREX_USE_GPU)
Gpu::dtod_memcpy_async
#else
std::memcpy
#endif
(dstfab->dataPtr(), srcfab->dataPtr(), dstfab->nBytes());
} else {
amrex::Abort("FFT::OpenBCSolver: how did this happen");
}
}
}

template <typename T>
void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
{
auto& inmf = m_r2c.m_rx;
inmf.setVal(T(0));
inmf.ParallelCopy(rho, 0, 0, 1);

m_r2c.forward(inmf);

auto scaling_factor = T(1) / T(m_r2c.m_real_domain.numPts());

auto const* gfab = detail::get_fab(m_G_fft);
if (gfab) {
auto [sd, ord] = m_r2c.getSpectralData();
amrex::ignore_unused(ord);
auto* rhofab = detail::get_fab(*sd);
if (rhofab) {
auto* pdst = rhofab->dataPtr();
auto const* psrc = gfab->dataPtr();
amrex::ParallelFor(rhofab->box().numPts(), [=] AMREX_GPU_DEVICE (Long i)
{
pdst[i] *= psrc[i] * scaling_factor;
});
} else {
amrex::Abort("FFT::OpenBCSolver::solve: how did this happen?");
}
}

m_r2c.backward_doit(phi, phi.nGrowVect());
}

}

#endif
80 changes: 79 additions & 1 deletion Src/FFT/AMReX_FFT_Poisson.H
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ namespace amrex::FFT
{

/**
* \brief Poisson solver for all periodic boundaries using FFT
* \brief Poisson solver for periodic, Dirichlet & Neumann boundaries using
* FFT.
*/
template <typename MF = MultiFab>
class Poisson
Expand Down Expand Up @@ -40,6 +41,32 @@ private:
R2X<typename MF::value_type> m_r2x;
};

#if (AMREX_SPACEDIM == 3)
/**
* \brief Poisson solve for Open BC using FFT.
*/
template <typename MF = MultiFab>
class PoissonOpenBC
{
public:

template <typename FA=MF, std::enable_if_t<IsFabArray_v<FA>,int> = 0>
explicit PoissonOpenBC (Geometry const& geom,
IndexType ixtype = IndexType::TheCellType(),
IntVect const& ngrow = IntVect(0));

void solve (MF& soln, MF const& rhs);

void define_doit (); // has to be public for cuda

private:
Geometry m_geom;
Box m_grown_domain;
IntVect m_ngrow;
OpenBCSolver<typename MF::value_type> m_solver;
};
#endif

/**
* \brief 3D Poisson solver for periodic boundaries in the first two
* dimensions and Neumann in the last dimension.
Expand Down Expand Up @@ -123,6 +150,57 @@ void Poisson<MF>::solve (MF& soln, MF const& rhs)
});
}

#if (AMREX_SPACEDIM == 3)

template <typename MF>
template <typename FA, std::enable_if_t<IsFabArray_v<FA>,int> FOO>
PoissonOpenBC<MF>::PoissonOpenBC (Geometry const& geom, IndexType ixtype,
IntVect const& ngrow)
: m_geom(geom),
m_grown_domain(amrex::grow(amrex::convert(geom.Domain(),ixtype),ngrow)),
m_ngrow(ngrow),
m_solver(m_grown_domain)
{
define_doit();
}

template <typename MF>
void PoissonOpenBC<MF>::define_doit ()
{
using T = typename MF::value_type;
auto const& lo = m_grown_domain.smallEnd();
auto const dx = T(m_geom.CellSize(0));
auto const dy = T(m_geom.CellSize(1));
auto const dz = T(m_geom.CellSize(2));
auto const gfac = T(1)/T(std::sqrt(T(12)));
// 0.125 comes from that there are 8 Gauss quadrature points
auto const fac = T(-0.125) * (dx*dy*dz) / (T(4)*Math::pi<T>());
m_solver.setGreensFunction([=] AMREX_GPU_DEVICE (int i, int j, int k) -> T
{
auto x = (T(i-lo[0]) - gfac) * dx; // first Gauss quadrature point
auto y = (T(j-lo[1]) - gfac) * dy;
auto z = (T(k-lo[2]) - gfac) * dz;
T r = 0;
for (int gx = 0; gx < 2; ++gx) {
for (int gy = 0; gy < 2; ++gy) {
for (int gz = 0; gz < 2; ++gz) {
auto xg = x + 2*gx*gfac*dx;
auto yg = y + 2*gy*gfac*dy;
auto zg = z + 2*gz*gfac*dz;
r += T(1)/std::sqrt(xg*xg+yg*yg+zg*zg);
}}}
return fac * r;
});
}

template <typename MF>
void PoissonOpenBC<MF>::solve (MF& soln, MF const& rhs)
{
m_solver.solve(soln, rhs);
}

#endif /* AMREX_SPACEDIM == 3 */

template <typename MF>
void PoissonHybrid<MF>::solve (MF& soln, MF const& rhs)
{
Expand Down
Loading

0 comments on commit d4c01c3

Please sign in to comment.