Skip to content

Commit

Permalink
Cleanup stencils for filtering (#4985)
Browse files Browse the repository at this point in the history
* Clean up the FDTD stencils, simplifying the dimension macros

* Small fix

* Further clean up of macros

* Clean up of macros in BilinearFilter

* Fix in 2D for NCIGodfreyFilter

* Removed AMREX_SPACEDIM

* Some clean up

* Bug fix

* Clean up of DoFilter for GPU

* Revert previous commit

* Change (x,y,z) to (0,1,2) to be more general
  • Loading branch information
dpgrote authored Sep 13, 2024
1 parent 73e1f84 commit 32737be
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 119 deletions.
33 changes: 12 additions & 21 deletions Source/Filter/BilinearFilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,34 +64,25 @@ void BilinearFilter::ComputeStencils(){
WARPX_PROFILE("BilinearFilter::ComputeStencils()");
int i = 0;
for (const auto& el : npass_each_dir ) {
stencil_length_each_dir[i++] = static_cast<int>(el);
stencil_length_each_dir[i++] = static_cast<int>(el) + 1;
}
stencil_length_each_dir += 1.;

m_stencil_0.resize( 1u + npass_each_dir[0] );
compute_stencil(m_stencil_0, npass_each_dir[0]);
#if defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ) || defined(WARPX_DIM_3D)
m_stencil_1.resize( 1u + npass_each_dir[1] );
compute_stencil(m_stencil_1, npass_each_dir[1]);
#endif
#if defined(WARPX_DIM_3D)
// npass_each_dir = npass_x npass_y npass_z
stencil_x.resize( 1u + npass_each_dir[0] );
stencil_y.resize( 1u + npass_each_dir[1] );
stencil_z.resize( 1u + npass_each_dir[2] );
compute_stencil(stencil_x, npass_each_dir[0]);
compute_stencil(stencil_y, npass_each_dir[1]);
compute_stencil(stencil_z, npass_each_dir[2]);
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
// npass_each_dir = npass_x npass_z
stencil_x.resize( 1u + npass_each_dir[0] );
stencil_z.resize( 1u + npass_each_dir[1] );
compute_stencil(stencil_x, npass_each_dir[0]);
compute_stencil(stencil_z, npass_each_dir[1]);
#elif defined(WARPX_DIM_1D_Z)
// npass_each_dir = npass_z
stencil_z.resize( 1u + npass_each_dir[0] );
compute_stencil(stencil_z, npass_each_dir[0]);
m_stencil_2.resize( 1u + npass_each_dir[2] );
compute_stencil(m_stencil_2, npass_each_dir[2]);
#endif

slen = stencil_length_each_dir.dim3();
#if defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
#if defined(WARPX_DIM_1D_Z) || defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
slen.z = 1;
#endif
#if defined(WARPX_DIM_1D_Z)
slen.y = 1;
slen.z = 1;
#endif
}
22 changes: 10 additions & 12 deletions Source/Filter/Filter.H
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,28 @@ public:

// Apply stencil on MultiFab.
// Guard cells are handled inside this function
void ApplyStencil(amrex::MultiFab& dstmf,
const amrex::MultiFab& srcmf, int lev, int scomp=0,
int dcomp=0, int ncomp=10000);
void ApplyStencil (amrex::MultiFab& dstmf,
const amrex::MultiFab& srcmf, int lev, int scomp=0,
int dcomp=0, int ncomp=10000);

// Apply stencil on a single FArrayBox.
void ApplyStencil (amrex::FArrayBox& dstfab,
const amrex::FArrayBox& srcfab, const amrex::Box& tbx,
int scomp=0, int dcomp=0, int ncomp=10000);

// public for cuda
void DoFilter(const amrex::Box& tbx,
amrex::Array4<amrex::Real const> const& tmp,
amrex::Array4<amrex::Real > const& dst,
int scomp, int dcomp, int ncomp);
void DoFilter (const amrex::Box& tbx,
amrex::Array4<amrex::Real const> const& tmp,
amrex::Array4<amrex::Real > const& dst,
int scomp, int dcomp, int ncomp);

// In 2D, stencil_length_each_dir = {length(stencil_x), length(stencil_z)}
// Length of stencil in each included direction
amrex::IntVect stencil_length_each_dir;

protected:
// Stencil along each direction.
// in 2D, stencil_y is not initialized.
amrex::Gpu::DeviceVector<amrex::Real> stencil_x, stencil_y, stencil_z;
// Length of each stencil.
// In 2D, slen = {length(stencil_x), length(stencil_z), 1}
amrex::Gpu::DeviceVector<amrex::Real> m_stencil_0, m_stencil_1, m_stencil_2;
// Length of each stencil, 1 for dimensions not included
amrex::Dim3 slen;

private:
Expand Down
138 changes: 61 additions & 77 deletions Source/Filter/Filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,21 @@ Filter::ApplyStencil (FArrayBox& dstfab, const FArrayBox& srcfab,
DoFilter(tbx, src, dst, scomp, dcomp, ncomp);
}

/* \brief Apply stencil (2D/3D, CPU/GPU)
/* \brief Apply stencil (CPU/GPU)
*/
void Filter::DoFilter (const Box& tbx,
Array4<Real const> const& src,
Array4<Real > const& dst,
int scomp, int dcomp, int ncomp)
{
#if (AMREX_SPACEDIM >= 2)
amrex::Real const* AMREX_RESTRICT sx = stencil_x.data();
#endif
#if defined(WARPX_DIM_3D)
amrex::Real const* AMREX_RESTRICT sy = stencil_y.data();
#endif
amrex::Real const* AMREX_RESTRICT sz = stencil_z.data();
AMREX_D_TERM(
amrex::Real const* AMREX_RESTRICT s0 = m_stencil_0.data();,
amrex::Real const* AMREX_RESTRICT s1 = m_stencil_1.data();,
amrex::Real const* AMREX_RESTRICT s2 = m_stencil_2.data();
)
Dim3 slen_local = slen;

#if defined(WARPX_DIM_3D)
#if AMREX_SPACEDIM == 3
AMREX_PARALLEL_FOR_4D ( tbx, ncomp, i, j, k, n,
{
Real d = 0.0;
Expand All @@ -115,25 +113,25 @@ void Filter::DoFilter (const Box& tbx,
return src.contains(jj,kk,ll) ? src(jj,kk,ll,nn) : 0.0_rt;
};

for (int iz=0; iz < slen_local.z; ++iz){
for (int iy=0; iy < slen_local.y; ++iy){
for (int ix=0; ix < slen_local.x; ++ix){
Real sss = sx[ix]*sy[iy]*sz[iz];
d += sss*( src_zeropad(i-ix,j-iy,k-iz,scomp+n)
+src_zeropad(i+ix,j-iy,k-iz,scomp+n)
+src_zeropad(i-ix,j+iy,k-iz,scomp+n)
+src_zeropad(i+ix,j+iy,k-iz,scomp+n)
+src_zeropad(i-ix,j-iy,k+iz,scomp+n)
+src_zeropad(i+ix,j-iy,k+iz,scomp+n)
+src_zeropad(i-ix,j+iy,k+iz,scomp+n)
+src_zeropad(i+ix,j+iy,k+iz,scomp+n));
for (int i2=0; i2 < slen_local.z; ++i2){
for (int i1=0; i1 < slen_local.y; ++i1){
for (int i0=0; i0 < slen_local.x; ++i0){
Real sss = s0[i0]*s1[i1]*s2[i2];
d += sss*( src_zeropad(i-i0,j-i1,k-i2,scomp+n)
+src_zeropad(i+i0,j-i1,k-i2,scomp+n)
+src_zeropad(i-i0,j+i1,k-i2,scomp+n)
+src_zeropad(i+i0,j+i1,k-i2,scomp+n)
+src_zeropad(i-i0,j-i1,k+i2,scomp+n)
+src_zeropad(i+i0,j-i1,k+i2,scomp+n)
+src_zeropad(i-i0,j+i1,k+i2,scomp+n)
+src_zeropad(i+i0,j+i1,k+i2,scomp+n));
}
}
}

dst(i,j,k,dcomp+n) = d;
});
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
#elif AMREX_SPACEDIM == 2
AMREX_PARALLEL_FOR_4D ( tbx, ncomp, i, j, k, n,
{
Real d = 0.0;
Expand All @@ -145,21 +143,21 @@ void Filter::DoFilter (const Box& tbx,
return src.contains(jj,kk,ll) ? src(jj,kk,ll,nn) : 0.0_rt;
};

for (int iz=0; iz < slen_local.z; ++iz){
for (int iy=0; iy < slen_local.y; ++iy){
for (int ix=0; ix < slen_local.x; ++ix){
Real sss = sx[ix]*sz[iy];
d += sss*( src_zeropad(i-ix,j-iy,k,scomp+n)
+src_zeropad(i+ix,j-iy,k,scomp+n)
+src_zeropad(i-ix,j+iy,k,scomp+n)
+src_zeropad(i+ix,j+iy,k,scomp+n));
for (int i2=0; i2 < slen_local.z; ++i2){
for (int i1=0; i1 < slen_local.y; ++i1){
for (int i0=0; i0 < slen_local.x; ++i0){
Real sss = s0[i0]*s1[i1];
d += sss*( src_zeropad(i-i0,j-i1,k,scomp+n)
+src_zeropad(i+i0,j-i1,k,scomp+n)
+src_zeropad(i-i0,j+i1,k,scomp+n)
+src_zeropad(i+i0,j+i1,k,scomp+n));
}
}
}

dst(i,j,k,dcomp+n) = d;
});
#elif defined(WARPX_DIM_1D_Z)
#elif AMREX_SPACEDIM == 1
AMREX_PARALLEL_FOR_4D ( tbx, ncomp, i, j, k, n,
{
Real d = 0.0;
Expand All @@ -171,21 +169,18 @@ void Filter::DoFilter (const Box& tbx,
return src.contains(jj,kk,ll) ? src(jj,kk,ll,nn) : 0.0_rt;
};

for (int iz=0; iz < slen_local.z; ++iz){
for (int iy=0; iy < slen_local.y; ++iy){
for (int ix=0; ix < slen_local.x; ++ix){
Real sss = sz[iy];
d += sss*( src_zeropad(i-ix,j,k,scomp+n)
+src_zeropad(i+ix,j,k,scomp+n));
for (int i2=0; i2 < slen_local.z; ++i2){
for (int i1=0; i1 < slen_local.y; ++i1){
for (int i0=0; i0 < slen_local.x; ++i0){
Real sss = s0[i0];
d += sss*( src_zeropad(i-i0,j,k,scomp+n)
+src_zeropad(i+i0,j,k,scomp+n));
}
}
}

dst(i,j,k,dcomp+n) = d;
});
#else
WARPX_ABORT_WITH_MESSAGE(
"Filter not implemented for the current geometry!");
#endif
}

Expand Down Expand Up @@ -278,13 +273,11 @@ void Filter::DoFilter (const Box& tbx,
const auto lo = amrex::lbound(tbx);
const auto hi = amrex::ubound(tbx);
// tmp and dst are of type Array4 (Fortran ordering)
#if (AMREX_SPACEDIM >= 2)
amrex::Real const* AMREX_RESTRICT sx = stencil_x.data();
#endif
#if defined(WARPX_DIM_3D)
amrex::Real const* AMREX_RESTRICT sy = stencil_y.data();
#endif
amrex::Real const* AMREX_RESTRICT sz = stencil_z.data();
AMREX_D_TERM(
amrex::Real const* AMREX_RESTRICT s0 = m_stencil_0.data();,
amrex::Real const* AMREX_RESTRICT s1 = m_stencil_1.data();,
amrex::Real const* AMREX_RESTRICT s2 = m_stencil_2.data();
)
for (int n = 0; n < ncomp; ++n) {
// Set dst value to 0.
for (int k = lo.z; k <= hi.z; ++k) {
Expand All @@ -295,41 +288,32 @@ void Filter::DoFilter (const Box& tbx,
}
}
// Three nested loops over the stencil points (slen is 1 along dimensions not included)
for (int iz=0; iz < slen.z; ++iz){
for (int iy=0; iy < slen.y; ++iy){
for (int ix=0; ix < slen.x; ++ix){
#if defined(WARPX_DIM_3D)
const Real sss = sx[ix]*sy[iy]*sz[iz];
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
const Real sss = sx[ix]*sz[iy];
#else
const Real sss = sz[ix];
#endif
for (int i2=0; i2 < slen.z; ++i2){
for (int i1=0; i1 < slen.y; ++i1){
for (int i0=0; i0 < slen.x; ++i0){
const Real sss = AMREX_D_TERM(s0[i0], *s1[i1], *s2[i2]);
// Three nested loops over the cells of the box tbx
for (int k = lo.z; k <= hi.z; ++k) {
for (int j = lo.y; j <= hi.y; ++j) {
AMREX_PRAGMA_SIMD
for (int i = lo.x; i <= hi.x; ++i) {
#if defined(WARPX_DIM_3D)
dst(i,j,k,dcomp+n) += sss*(tmp(i-ix,j-iy,k-iz,scomp+n)
+tmp(i+ix,j-iy,k-iz,scomp+n)
+tmp(i-ix,j+iy,k-iz,scomp+n)
+tmp(i+ix,j+iy,k-iz,scomp+n)
+tmp(i-ix,j-iy,k+iz,scomp+n)
+tmp(i+ix,j-iy,k+iz,scomp+n)
+tmp(i-ix,j+iy,k+iz,scomp+n)
+tmp(i+ix,j+iy,k+iz,scomp+n));
#elif defined(WARPX_DIM_XZ) || defined(WARPX_DIM_RZ)
dst(i,j,k,dcomp+n) += sss*(tmp(i-ix,j-iy,k,scomp+n)
+tmp(i+ix,j-iy,k,scomp+n)
+tmp(i-ix,j+iy,k,scomp+n)
+tmp(i+ix,j+iy,k,scomp+n));
#elif defined(WARPX_DIM_1D_Z)
dst(i,j,k,dcomp+n) += sss*(tmp(i-ix,j,k,scomp+n)
+tmp(i+ix,j,k,scomp+n));
#else
WARPX_ABORT_WITH_MESSAGE(
"Filter not implemented for the current geometry!");
#if AMREX_SPACEDIM == 3
dst(i,j,k,dcomp+n) += sss*(tmp(i-i0,j-i1,k-i2,scomp+n)
+tmp(i+i0,j-i1,k-i2,scomp+n)
+tmp(i-i0,j+i1,k-i2,scomp+n)
+tmp(i+i0,j+i1,k-i2,scomp+n)
+tmp(i-i0,j-i1,k+i2,scomp+n)
+tmp(i+i0,j-i1,k+i2,scomp+n)
+tmp(i-i0,j+i1,k+i2,scomp+n)
+tmp(i+i0,j+i1,k+i2,scomp+n));
#elif AMREX_SPACEDIM == 2
dst(i,j,k,dcomp+n) += sss*(tmp(i-i0,j-i1,k,scomp+n)
+tmp(i+i0,j-i1,k,scomp+n)
+tmp(i-i0,j+i1,k,scomp+n)
+tmp(i+i0,j+i1,k,scomp+n));
#elif AMREX_SPACEDIM == 1
dst(i,j,k,dcomp+n) += sss*(tmp(i-i0,j,k,scomp+n)
+tmp(i+i0,j,k,scomp+n));
#endif
}
}
Expand Down
19 changes: 10 additions & 9 deletions Source/Filter/NCIGodfreyFilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,18 @@ void NCIGodfreyFilter::ComputeStencils()
# endif
h_stencil_z[0] /= 2._rt;

stencil_x.resize(h_stencil_x.size());
m_stencil_0.resize(h_stencil_x.size());
Gpu::copyAsync(Gpu::hostToDevice,h_stencil_x.begin(),h_stencil_x.end(),m_stencil_0.begin());
# if defined(WARPX_DIM_3D)
stencil_y.resize(h_stencil_y.size());
m_stencil_1.resize(h_stencil_y.size());
m_stencil_2.resize(h_stencil_z.size());
Gpu::copyAsync(Gpu::hostToDevice,h_stencil_y.begin(),h_stencil_y.end(),m_stencil_1.begin());
Gpu::copyAsync(Gpu::hostToDevice,h_stencil_z.begin(),h_stencil_z.end(),m_stencil_2.begin());
# elif (AMREX_SPACEDIM == 2)
// In 2D, the filter applies m_stencil_1 to the 2nd dimension, so the z stencil is stored in m_stencil_1
m_stencil_1.resize(h_stencil_z.size());
Gpu::copyAsync(Gpu::hostToDevice,h_stencil_z.begin(),h_stencil_z.end(),m_stencil_1.begin());
# endif
stencil_z.resize(h_stencil_z.size());

Gpu::copyAsync(Gpu::hostToDevice,h_stencil_x.begin(),h_stencil_x.end(),stencil_x.begin());
# if defined(WARPX_DIM_3D)
Gpu::copyAsync(Gpu::hostToDevice,h_stencil_y.begin(),h_stencil_y.end(),stencil_y.begin());
# endif
Gpu::copyAsync(Gpu::hostToDevice,h_stencil_z.begin(),h_stencil_z.end(),stencil_z.begin());

Gpu::synchronize();
}
Expand Down

0 comments on commit 32737be

Please sign in to comment.