Skip to content

Commit

Permalink
Some small cleanup of clover interface functions
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Dec 7, 2023
1 parent 3aed5e0 commit c310d9c
Showing 1 changed file with 55 additions and 105 deletions.
160 changes: 55 additions & 105 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4520,47 +4520,32 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double
fParam.setPrecision(fParam.Precision(), true);
GaugeField cudaForce(fParam);

ColorSpinorParam qParam;
qParam.location = QUDA_CUDA_FIELD_LOCATION;
qParam.nColor = 3;
qParam.nSpin = 4;
qParam.siteSubset = QUDA_FULL_SITE_SUBSET;
qParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
qParam.nDim = 4;
ColorSpinorParam qParam(nullptr, *inv_param, fParam.x, false, QUDA_CUDA_FIELD_LOCATION);
qParam.setPrecision(fParam.Precision(), fParam.Precision(), true);
for(int dir=0; dir<4; ++dir) qParam.x[dir] = fParam.x[dir];

// create the device quark field
qParam.create = QUDA_NULL_FIELD_CREATE;
qParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS;

std::vector<ColorSpinorField> quarkX, quarkP;
std::vector<ColorSpinorField> x(nvector), p(nvector);
for (int i = 0; i < nvector; i++) {
quarkX[i] = ColorSpinorField(qParam);
quarkP[i] = ColorSpinorField(qParam);
}

// for downloading x_e
qParam.siteSubset = QUDA_PARITY_SITE_SUBSET;
qParam.x[0] /= 2;
p[i] = ColorSpinorField(qParam);
x[i] = ColorSpinorField(qParam);

// create the host quark field
qParam.location = QUDA_CPU_FIELD_LOCATION;
qParam.create = QUDA_REFERENCE_FIELD_CREATE;
qParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
qParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // need expose this to interface
if (!inv_param->use_resident_solution) {
ColorSpinorParam cpuParam(h_x[i], *inv_param, fParam.x, true, inv_param->input_location);
ColorSpinorField cpuQuarkX(cpuParam);
x[i].Even() = cpuQuarkX;
gamma5(x[i].Even(), x[i].Even());
} else {
x[i].Even() = solutionResident[i];
}
}

bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) ||
(inv_param->solve_type == QUDA_NORMOP_PC_SOLVE);
DiracParam diracParam;
setDiracParam(diracParam, inv_param, pc_solve);
setDiracParam(diracParam, inv_param, true);
Dirac *dirac = Dirac::create(diracParam);

if (inv_param->use_resident_solution) {
if (solutionResident.size() < (unsigned int)nvector)
errorQuda("solutionResident.size() %lu does not match number of shifts %d",
solutionResident.size(), nvector);
}
if (inv_param->use_resident_solution && solutionResident.size() < (unsigned int)nvector)
errorQuda("solutionResident.size() %lu does not match number of shifts %d", solutionResident.size(), nvector);

// create oprod and trace fields
fParam.geometry = QUDA_TENSOR_GEOMETRY;
Expand All @@ -4570,33 +4555,20 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double

profileCloverForce.TPSTART(QUDA_PROFILE_COMPUTE);
// loop over different quark fields
for(int i=0; i<nvector; i++){
ColorSpinorField &x = quarkX[i];
ColorSpinorField &p = quarkP[i];

if (!inv_param->use_resident_solution) {
// Wrap the even-parity MILC quark field
qParam.v = h_x[i];
ColorSpinorField cpuQuarkX(qParam); // create host quark field
x.Even() = cpuQuarkX;
gamma5(x.Even(), x.Even());
} else {
x.Even() = solutionResident[i];
}
for (int i = 0; i < nvector; i++) {
force_coeff[i] = 2.0 * dt * coeff[i] * kappa2;

dirac->Dslash(x.Odd(), x.Even(), QUDA_ODD_PARITY);
dirac->M(p.Even(), x.Even());
dirac->Dslash(x[i].Odd(), x[i].Even(), QUDA_ODD_PARITY);
dirac->M(p[i].Even(), x[i].Even());
dirac->Dagger(QUDA_DAG_YES);
dirac->Dslash(p.Odd(), p.Even(), QUDA_ODD_PARITY);
dirac->Dslash(p[i].Odd(), p[i].Even(), QUDA_ODD_PARITY);
dirac->Dagger(QUDA_DAG_NO);

gamma5(x, x);
gamma5(p, p);

force_coeff[i] = 2.0*dt*coeff[i]*kappa2;
gamma5(x[i], x[i]);
gamma5(p[i], p[i]);
}

computeCloverForce(cudaForce, *gaugePrecise, quarkX, quarkP, force_coeff);
computeCloverForce(cudaForce, *gaugePrecise, x, p, force_coeff);

// Make sure extendedGaugeResident has the correct R
// TODO: In most situation, deallocation is unnecessery
Expand All @@ -4610,7 +4582,7 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double
std::vector< array<double, 2> > ferm_epsilon(nvector);
for (int i = 0; i < nvector; i++) ferm_epsilon[i] = {2.0*ck*coeff[i]*dt, -kappa2 * 2.0*ck*coeff[i]*dt};

computeCloverSigmaOprod(oprod, quarkX, quarkP, ferm_epsilon);
computeCloverSigmaOprod(oprod, x, p, ferm_epsilon);

cloverDerivative(cudaForce, gaugeEx, oprod, 1.0);

Expand Down Expand Up @@ -4667,43 +4639,35 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef
gParamMom.setPrecision(gParamMom.Precision(), true);
GaugeField cudaForce(gParamMom);

ColorSpinorParam qParam;
qParam.location = QUDA_CUDA_FIELD_LOCATION;
qParam.nColor = 3;
qParam.nSpin = 4;
qParam.siteSubset = QUDA_FULL_SITE_SUBSET;
qParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER;
qParam.nDim = 4;
ColorSpinorParam qParam(nullptr, *inv_param, gParamMom.x, false, QUDA_CUDA_FIELD_LOCATION);
qParam.setPrecision(gauge_param->cuda_prec, gauge_param->cuda_prec, true);
qParam.twistFlavor = inv_param->twist_flavor;
qParam.pc_type = inv_param->dslash_type == QUDA_DOMAIN_WALL_DSLASH ? QUDA_5D_PC : QUDA_4D_PC;
for(int dir = 0; dir<4; ++dir) qParam.x[dir] = gParamMom.x[dir];

// create the device quark field
qParam.create = QUDA_NULL_FIELD_CREATE;
qParam.gammaBasis = QUDA_UKQCD_GAMMA_BASIS;

std::vector<ColorSpinorField> quarkX(nvector), quarkP(nvector), quarkX0(nvector);
std::vector<ColorSpinorField> x(nvector), p(nvector), x0(nvector);
for (int i = 0; i < nvector; i++) {
quarkX[i] = ColorSpinorField(qParam);
quarkP[i] = ColorSpinorField(qParam);
if (detratio) quarkX0[i] = ColorSpinorField(qParam);
p[i] = ColorSpinorField(qParam);

x[i] = ColorSpinorField(qParam);
ColorSpinorParam cpuParam(h_x[i], *inv_param, gParamMom.x, true, inv_param->input_location);
ColorSpinorField cpuQuarkX(cpuParam);
x[i].Odd() = cpuQuarkX; // in tmLQCD-parlance this is the odd part of X

if (detratio) {
x0[i] = ColorSpinorField(qParam);

ColorSpinorParam cpuParam0(h_x0[i], *inv_param, gParamMom.x, true, inv_param->input_location);
ColorSpinorField cpuQuarkX0(cpuParam0);
x0[i].Odd() = cpuQuarkX0;
}
}

qParam.siteSubset = QUDA_PARITY_SITE_SUBSET;
qParam.x[0] /= 2;
ColorSpinorField tmp(qParam);

// create the host quark field
qParam.location = QUDA_CPU_FIELD_LOCATION;
qParam.create = QUDA_REFERENCE_FIELD_CREATE;
qParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
qParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS;

bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) ||
(inv_param->solve_type == QUDA_NORMOP_PC_SOLVE);
DiracParam diracParam;
setDiracParam(diracParam, inv_param, pc_solve);
setDiracParam(diracParam, inv_param, true);
Dirac *dirac = Dirac::create(diracParam);

// Make sure extendedGaugeResident has the correct R
Expand All @@ -4721,54 +4685,40 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef
for (int i = 0; i < nvector; i++) {
force_coeff[i] = 1.0 * coeff[i];

ColorSpinorField &x = quarkX[i];
ColorSpinorField &p = quarkP[i];

const auto &gauge = (inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise;
ColorSpinorParam cpuParam(h_x[i], *inv_param, gauge.X(), true, inv_param->input_location);
ColorSpinorField cpuQuarkX(cpuParam);

x.Odd() = cpuQuarkX; // in tmLQCD-parlance this is the odd part of X

if (inv_param->matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC || inv_param->matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
dirac->Dagger(QUDA_DAG_YES);
gamma5(tmp, x.Odd());
dirac->Dslash(x.Even(), tmp, QUDA_EVEN_PARITY);
gamma5(x.Even(), x.Even());
gamma5(tmp, x[i].Odd());
dirac->Dslash(x[i].Even(), tmp, QUDA_EVEN_PARITY);
gamma5(x[i].Even(), x[i].Even());

// want to apply \hat Q_{-} = \hat M_{+}^\dagger \gamma_5 to get Y_o
dirac->Dagger(QUDA_DAG_YES);
dirac->M(p.Odd(), tmp); // this is the odd part of Y
dirac->M(p[i].Odd(), tmp); // this is the odd part of Y
dirac->Dagger(QUDA_DAG_NO);

if (detratio){
ColorSpinorParam cpuParam0(h_x0[i], *inv_param, gauge.X(), true, inv_param->input_location);
ColorSpinorField cpuQuarkX0(cpuParam0);
ColorSpinorField &x0 = quarkX0[i];
x0.Odd()=cpuQuarkX0;
blas::axpbyz(1, p.Odd(), 1, x0.Odd(), p.Odd());
}
dirac->Dslash(p.Even(), p.Odd(), QUDA_EVEN_PARITY); // and now the even part of Y
if (detratio) blas::xpy(x0[i].Odd(), p[i].Odd());

dirac->Dslash(p[i].Even(), p[i].Odd(), QUDA_EVEN_PARITY); // and now the even part of Y
// up to here x.odd match X.odd in tmLQCD and p.odd=-Y.odd of tmLQCD
// x.Even= X.Even.tmLQCD*kappa and p.Even=-Y.Even.tmLQCD*kappa

// the gamma5 application in tmLQCD is done inside deriv_Sb
gamma5(p, p);
// the gamma5 application in tmLQCD is done inside deriv_Sb
gamma5(p[i], p[i]);
} else {
errorQuda("computeTMCloverForceQuda: MATPC type not supported");
errorQuda("MatPC type %d not supported", inv_param->matpc_type);
}
}

// derivative of the wilson operator it correspond to deriv_Sb(OE,...) plus deriv_Sb(EO,...) in tmLQCD
computeCloverForce(cudaForce, *gaugePrecise, quarkX, quarkP, force_coeff);
computeCloverForce(cudaForce, *gaugePrecise, x, p, force_coeff);
// derivative of the determinant of the sw term, second term of (A12) in hep-lat/0112051, sw_deriv(EE, mnl->mu) in tmLQCD
if (!detratio) computeCloverSigmaTrace(oprod, *cloverPrecise, k_csw_ov_8 * 32.0, 0);

std::vector< array<double, 2> > ferm_epsilon(nvector);
for (int i = 0; i < nvector; i++) ferm_epsilon[i] = { k_csw_ov_8 * coeff[i], k_csw_ov_8 * coeff[i]/(kappa*kappa) };

// derivative of pseudofermion sw term, first term term of (A12) in hep-lat/0112051, sw_spinor_eo(EE,..) plus sw_spinor_eo(OO,..) in tmLQCD
computeCloverSigmaOprod(oprod, quarkP, quarkX, ferm_epsilon);
computeCloverSigmaOprod(oprod, p, x, ferm_epsilon);

// oprod = (A12) of hep-lat/0112051
// compute the insertion of oprod in Fig.27 of hep-lat/0112051
Expand Down

0 comments on commit c310d9c

Please sign in to comment.