Skip to content

Commit

Permalink
Clean up FFT functions for J, rho
Browse files Browse the repository at this point in the history
  • Loading branch information
EZoni committed Sep 20, 2024
1 parent c2b481a commit 98f9670
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 63 deletions.
24 changes: 12 additions & 12 deletions Source/Evolve/WarpXEvolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,16 +677,18 @@ WarpX::OneStep_multiJ (const amrex::Real cur_time)
// (after checking that pointer to rho_fp on MR level 0 is not null)
if (m_fields.has("rho_fp", 0) && rho_in_time == RhoInTime::Linear)
{
const ablastr::fields::MultiLevelScalarField rho_fp = m_fields.get_mr_levels("rho_fp", finest_level);
const ablastr::fields::MultiLevelScalarField rho_cp = m_fields.get_mr_levels("rho_cp", finest_level);
ablastr::fields::MultiLevelScalarField const rho_fp = m_fields.get_mr_levels("rho_fp", finest_level);

std::string const rho_fp_string = "rho_fp";
std::string const rho_cp_string = "rho_cp";

// Deposit rho at relative time -dt
// (dt[0] denotes the time step on mesh refinement level 0)
mypc->DepositCharge(rho_fp, -dt[0]);
// Filter, exchange boundary, and interpolate across levels
SyncRho();
// Forward FFT of rho
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_new);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_new);
}

// 4) Deposit J at relative time -dt with time step dt
Expand All @@ -702,9 +704,7 @@ WarpX::OneStep_multiJ (const amrex::Real cur_time)
// of guard cells.
SyncCurrent("current_fp");
// Forward FFT of J
PSATDForwardTransformJ(
m_fields.get_mr_levels_alldirs( "current_fp", finest_level),
m_fields.get_mr_levels_alldirs( "current_cp", finest_level) );
PSATDForwardTransformJ("current_fp", "current_cp");
}

// Number of depositions for multi-J scheme
Expand Down Expand Up @@ -738,16 +738,16 @@ WarpX::OneStep_multiJ (const amrex::Real cur_time)
// of guard cells.
SyncCurrent("current_fp");
// Forward FFT of J
PSATDForwardTransformJ(
m_fields.get_mr_levels_alldirs( "current_fp", finest_level),
m_fields.get_mr_levels_alldirs( "current_cp", finest_level) );
PSATDForwardTransformJ("current_fp", "current_cp");

// Deposit new rho
// (after checking that pointer to rho_fp on MR level 0 is not null)
if (m_fields.has("rho_fp", 0))
{
const ablastr::fields::MultiLevelScalarField rho_fp = m_fields.get_mr_levels("rho_fp", finest_level);
const ablastr::fields::MultiLevelScalarField rho_cp = m_fields.get_mr_levels("rho_cp", finest_level);
ablastr::fields::MultiLevelScalarField const rho_fp = m_fields.get_mr_levels("rho_fp", finest_level);

std::string const rho_fp_string = "rho_fp";
std::string const rho_cp_string = "rho_cp";

// Move rho from new to old if rho is linear in time
if (rho_in_time == RhoInTime::Linear) { PSATDMoveRhoNewToRhoOld(); }
Expand All @@ -758,7 +758,7 @@ WarpX::OneStep_multiJ (const amrex::Real cur_time)
SyncRho();
// Forward FFT of rho
const int rho_idx = (rho_in_time == RhoInTime::Linear) ? rho_new : rho_mid;
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_idx);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_idx);
}

if (WarpX::current_correction)
Expand Down
116 changes: 71 additions & 45 deletions Source/FieldSolver/WarpXPushFieldsEM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,12 @@ WarpX::PSATDBackwardTransformG ()
}

void WarpX::PSATDForwardTransformJ (
const ablastr::fields::MultiLevelVectorField& J_fp,
const ablastr::fields::MultiLevelVectorField& J_cp,
std::string const & J_fp_string,
std::string const & J_cp_string,
const bool apply_kspace_filter)
{
if (!m_fields.has_vector(J_fp_string, 0)) { return; }

SpectralFieldIndex Idx;
int idx_jx, idx_jy, idx_jz;

Expand All @@ -325,7 +327,10 @@ void WarpX::PSATDForwardTransformJ (
idx_jy = (J_in_time == JInTime::Linear) ? static_cast<int>(Idx.Jy_new) : static_cast<int>(Idx.Jy_mid);
idx_jz = (J_in_time == JInTime::Linear) ? static_cast<int>(Idx.Jz_new) : static_cast<int>(Idx.Jz_mid);

ForwardTransformVect(lev, *spectral_solver_fp[lev], J_fp[lev], idx_jx, idx_jy, idx_jz);
if (m_fields.has_vector(J_fp_string, lev)) {
ablastr::fields::VectorField J_fp = m_fields.get_alldirs(J_fp_string, lev);
ForwardTransformVect(lev, *spectral_solver_fp[lev], J_fp, idx_jx, idx_jy, idx_jz);
}

if (spectral_solver_cp[lev])
{
Expand All @@ -335,7 +340,10 @@ void WarpX::PSATDForwardTransformJ (
idx_jy = (J_in_time == JInTime::Linear) ? static_cast<int>(Idx.Jy_new) : static_cast<int>(Idx.Jy_mid);
idx_jz = (J_in_time == JInTime::Linear) ? static_cast<int>(Idx.Jz_new) : static_cast<int>(Idx.Jz_mid);

ForwardTransformVect(lev, *spectral_solver_cp[lev], J_cp[lev], idx_jx, idx_jy, idx_jz);
if (m_fields.has_vector(J_cp_string, lev)) {
ablastr::fields::VectorField J_cp = m_fields.get_alldirs(J_cp_string, lev);
ForwardTransformVect(lev, *spectral_solver_cp[lev], J_cp, idx_jx, idx_jy, idx_jz);
}
}
}

Expand Down Expand Up @@ -371,9 +379,11 @@ void WarpX::PSATDForwardTransformJ (
}

void WarpX::PSATDBackwardTransformJ (
ablastr::fields::MultiLevelVectorField const & J_fp,
ablastr::fields::MultiLevelVectorField const & J_cp)
std::string const & J_fp_string,
std::string const & J_cp_string)
{
if (!m_fields.has_vector(J_fp_string, 0)) { return; }

SpectralFieldIndex Idx;
int idx_jx, idx_jy, idx_jz;

Expand All @@ -387,8 +397,11 @@ void WarpX::PSATDBackwardTransformJ (
idx_jy = static_cast<int>(Idx.Jy_mid);
idx_jz = static_cast<int>(Idx.Jz_mid);

BackwardTransformVect(lev, *spectral_solver_fp[lev], J_fp[lev],
idx_jx, idx_jy, idx_jz, m_fill_guards_current);
if (m_fields.has_vector(J_fp_string, lev)) {
ablastr::fields::VectorField J_fp = m_fields.get_alldirs(J_fp_string, lev);
BackwardTransformVect(lev, *spectral_solver_fp[lev], J_fp,
idx_jx, idx_jy, idx_jz, m_fill_guards_current);
}

if (spectral_solver_cp[lev])
{
Expand All @@ -400,26 +413,35 @@ void WarpX::PSATDBackwardTransformJ (
idx_jy = static_cast<int>(Idx.Jy_mid);
idx_jz = static_cast<int>(Idx.Jz_mid);

BackwardTransformVect(lev, *spectral_solver_cp[lev], J_cp[lev],
idx_jx, idx_jy, idx_jz, m_fill_guards_current);
if (m_fields.has_vector(J_cp_string, lev)) {
ablastr::fields::VectorField J_cp = m_fields.get_alldirs(J_cp_string, lev);
BackwardTransformVect(lev, *spectral_solver_cp[lev], J_cp,
idx_jx, idx_jy, idx_jz, m_fill_guards_current);
}
}
}
}

void WarpX::PSATDForwardTransformRho (
ablastr::fields::MultiLevelScalarField const & charge_fp,
ablastr::fields::MultiLevelScalarField const & charge_cp,
std::string const & charge_fp_string,
std::string const & charge_cp_string,
const int icomp, const int dcomp, const bool apply_kspace_filter)
{
if (charge_fp[0] == nullptr) { return; }
if (!m_fields.has(charge_fp_string, 0)) { return; }

for (int lev = 0; lev <= finest_level; ++lev)
{
if (charge_fp[lev]) { spectral_solver_fp[lev]->ForwardTransform(lev, *charge_fp[lev], dcomp, icomp); }
if (m_fields.has(charge_fp_string, lev)) {
amrex::MultiFab const & charge_fp = *m_fields.get(charge_fp_string, lev);
spectral_solver_fp[lev]->ForwardTransform(lev, charge_fp, dcomp, icomp);
}

if (spectral_solver_cp[lev])
{
if (charge_cp[lev]) { spectral_solver_cp[lev]->ForwardTransform(lev, *charge_cp[lev], dcomp, icomp); }
if (m_fields.has(charge_cp_string, lev)) {
amrex::MultiFab const & charge_cp = *m_fields.get(charge_cp_string, lev);
spectral_solver_cp[lev]->ForwardTransform(lev, charge_cp, dcomp, icomp);
}
}
}

Expand Down Expand Up @@ -699,53 +721,56 @@ WarpX::PushPSATD ()

const int rho_old = spectral_solver_fp[0]->m_spectral_index.rho_old;
const int rho_new = spectral_solver_fp[0]->m_spectral_index.rho_new;
const ablastr::fields::MultiLevelScalarField rho_fp = m_fields.get_mr_levels("rho_fp", finest_level);
const ablastr::fields::MultiLevelScalarField rho_cp = m_fields.get_mr_levels("rho_cp", finest_level);

std::string const rho_fp_string = "rho_fp";
std::string const rho_cp_string = "rho_cp";

const ablastr::fields::MultiLevelVectorField current_fp = m_fields.get_mr_levels_alldirs("current_fp", finest_level);
const ablastr::fields::MultiLevelVectorField current_cp = m_fields.get_mr_levels_alldirs("current_cp", finest_level);
const ablastr::fields::MultiLevelVectorField current_buf = m_fields.get_mr_levels_alldirs("current_buf", finest_level);
std::string current_fp_string = "current_fp";
std::string const current_cp_string = "current_cp";

if (fft_periodic_single_box)
{
if (current_correction)
{
// FFT of J and rho
PSATDForwardTransformJ(current_fp, current_cp);
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_old);
PSATDForwardTransformRho(rho_fp, rho_cp, 1, rho_new);
PSATDForwardTransformJ(current_fp_string, current_cp_string);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_old);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 1, rho_new);

// Correct J in k-space
PSATDCurrentCorrection();

// Inverse FFT of J
PSATDBackwardTransformJ(current_fp, current_cp);
PSATDBackwardTransformJ(current_fp_string, current_cp_string);
}
else if (current_deposition_algo == CurrentDepositionAlgo::Vay)
{
// FFT of D and rho (if used)
// TODO Replace current_cp with current_cp_vay once Vay deposition is implemented with MR
PSATDForwardTransformJ(
m_fields.get_mr_levels_alldirs("current_fp_vay", finest_level), current_cp);
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_old);
PSATDForwardTransformRho(rho_fp, rho_cp, 1, rho_new);
current_fp_string = "current_fp_vay";
PSATDForwardTransformJ(current_fp_string, current_cp_string);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_old);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 1, rho_new);

// Compute J from D in k-space
PSATDVayDeposition();

// Inverse FFT of J, subtract cumulative sums of D
PSATDBackwardTransformJ(current_fp, current_cp);
current_fp_string = "current_fp";
PSATDBackwardTransformJ(current_fp_string, current_cp_string);
// TODO Cumulative sums need to be fixed with periodic single box
PSATDSubtractCurrentPartialSumsAvg();

// FFT of J after subtraction of cumulative sums
PSATDForwardTransformJ(current_fp, current_cp);
PSATDForwardTransformJ(current_fp_string, current_cp_string);
}
else // no current correction, no Vay deposition
{
// FFT of J and rho (if used)
PSATDForwardTransformJ(current_fp, current_cp);
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_old);
PSATDForwardTransformRho(rho_fp, rho_cp, 1, rho_new);
PSATDForwardTransformJ(current_fp_string, current_cp_string);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_old);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 1, rho_new);
}
}
else // no periodic single box
Expand All @@ -757,20 +782,20 @@ WarpX::PushPSATD ()
// In RZ geometry, do not apply filtering here, since it is
// applied in the subsequent calls to these functions (below)
const bool apply_kspace_filter = false;
PSATDForwardTransformJ(current_fp, current_cp, apply_kspace_filter);
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_old, apply_kspace_filter);
PSATDForwardTransformRho(rho_fp, rho_cp, 1, rho_new, apply_kspace_filter);
PSATDForwardTransformJ(current_fp_string, current_cp_string, apply_kspace_filter);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_old, apply_kspace_filter);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 1, rho_new, apply_kspace_filter);
#else
PSATDForwardTransformJ(current_fp, current_cp);
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_old);
PSATDForwardTransformRho(rho_fp, rho_cp, 1, rho_new);
PSATDForwardTransformJ(current_fp_string, current_cp_string);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_old);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 1, rho_new);
#endif

// Correct J in k-space
PSATDCurrentCorrection();

// Inverse FFT of J
PSATDBackwardTransformJ(current_fp, current_cp);
PSATDBackwardTransformJ(current_fp_string, current_cp_string);

// Synchronize J and rho
SyncCurrent("current_fp");
Expand All @@ -779,14 +804,15 @@ WarpX::PushPSATD ()
else if (current_deposition_algo == CurrentDepositionAlgo::Vay)
{
// FFT of D
PSATDForwardTransformJ(
m_fields.get_mr_levels_alldirs("current_fp_vay", finest_level), current_cp);
current_fp_string = "current_fp_vay";
PSATDForwardTransformJ(current_fp_string, current_cp_string);

// Compute J from D in k-space
PSATDVayDeposition();

// Inverse FFT of J, subtract cumulative sums of D
PSATDBackwardTransformJ(current_fp, current_cp);
current_fp_string = "current_fp";
PSATDBackwardTransformJ(current_fp_string, current_cp_string);
PSATDSubtractCurrentPartialSumsAvg();

// Synchronize J and rho (if used).
Expand All @@ -800,9 +826,9 @@ WarpX::PushPSATD ()
}

// FFT of J and rho (if used)
PSATDForwardTransformJ(current_fp, current_cp);
PSATDForwardTransformRho(rho_fp, rho_cp, 0, rho_old);
PSATDForwardTransformRho(rho_fp, rho_cp, 1, rho_new);
PSATDForwardTransformJ(current_fp_string, current_cp_string);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 0, rho_old);
PSATDForwardTransformRho(rho_fp_string, rho_cp_string, 1, rho_new);
}

// FFT of E and B
Expand Down
12 changes: 6 additions & 6 deletions Source/WarpX.H
Original file line number Diff line number Diff line change
Expand Up @@ -1685,8 +1685,8 @@ private:
* (only used in RZ geometry to avoid double filtering)
*/
void PSATDForwardTransformJ (
ablastr::fields::MultiLevelVectorField const& J_fp,
ablastr::fields::MultiLevelVectorField const& J_cp,
std::string const & J_fp_string,
std::string const & J_cp_string,
bool apply_kspace_filter=true);

/**
Expand All @@ -1698,8 +1698,8 @@ private:
* storing the coarse patch current to be transformed
*/
void PSATDBackwardTransformJ (
ablastr::fields::MultiLevelVectorField const & J_fp,
ablastr::fields::MultiLevelVectorField const & J_cp);
std::string const & J_fp_string,
std::string const & J_cp_string);

/**
* \brief Forward FFT of rho on all mesh refinement levels,
Expand All @@ -1713,8 +1713,8 @@ private:
* (only used in RZ geometry to avoid double filtering)
*/
void PSATDForwardTransformRho (
ablastr::fields::MultiLevelScalarField const & charge_fp,
ablastr::fields::MultiLevelScalarField const & charge_cp,
std::string const & charge_fp_string,
std::string const & charge_cp_string,
int icomp, int dcomp, bool apply_kspace_filter=true);

/**
Expand Down

0 comments on commit 98f9670

Please sign in to comment.