Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: add CalAtomsInfo to modify parameter #5132

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
75 changes: 75 additions & 0 deletions source/module_cell/cal_atoms_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#ifndef CAL_ATOMS_INFO_H
#define CAL_ATOMS_INFO_H
#include "module_parameter/parameter.h"
#include "unitcell.h"
class CalAtomsInfo
{
public:
CalAtomsInfo(){};
~CalAtomsInfo(){};

/**
* @brief Calculate the atom information from pseudopotential to set Parameter
*
* @param atoms [in] Atom pointer
* @param ntype [in] number of atom types
* @param para [out] Parameter object
*/
void cal_atoms_info(const Atom* atoms, const int& ntype, Parameter& para)
{
// calculate initial total magnetization when NSPIN=2
if (para.inp.nspin == 2 && !para.globalv.two_fermi)
{
for (int it = 0; it < ntype; ++it)
{
for (int ia = 0; ia < atoms[it].na; ++ia)
{
GlobalV::nupdown += atoms[it].mag[ia];
}
}
GlobalV::ofs_running << " The readin total magnetization is " << GlobalV::nupdown << std::endl;
}

if (!para.inp.use_paw)
{
// decide whether to be USPP
for (int it = 0; it < ntype; ++it)
{
if (atoms[it].ncpp.tvanp)
{
GlobalV::use_uspp = true;
}
}

// calculate the total number of local basis
GlobalV::NLOCAL = 0;
for (int it = 0; it < ntype; ++it)
{
const int nlocal_it = atoms[it].nw * atoms[it].na;
if (para.inp.nspin != 4)
{
GlobalV::NLOCAL += nlocal_it;
}
else
{
GlobalV::NLOCAL += nlocal_it * 2; // zhengdy-soc
}
}
}

// calculate the total number of electrons
cal_nelec(atoms, ntype, GlobalV::nelec);

// autoset and check GlobalV::NBANDS
std::vector<double> nelec_spin(2, 0.0);
if (para.inp.nspin == 2)
{
nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0;
nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0;
}
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);

return;
}
};
#endif
4 changes: 2 additions & 2 deletions source/module_cell/module_neighbor/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in,
coor_type(coor_type_in),
coordinates(coordinates_in)
{
mbl = {0};
velocity = {0};
mbl = std::valarray<double>(0.0, coordinates_in.size());
velocity = std::valarray<double>(0.0, coordinates_in.size());
}

UcellTestPrepare::UcellTestPrepare(std::string latname_in,
Expand Down
4 changes: 2 additions & 2 deletions source/module_cell/test/prepare_unitcell.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ UcellTestPrepare::UcellTestPrepare(std::string latname_in,
coor_type(coor_type_in),
coordinates(coordinates_in)
{
mbl = {0};
velocity = {0};
mbl = std::valarray<double>(0.0, coordinates_in.size());
velocity = std::valarray<double>(0.0, coordinates_in.size());
}

UcellTestPrepare::UcellTestPrepare(std::string latname_in,
Expand Down
2 changes: 1 addition & 1 deletion source/module_cell/test/support/mock_unitcell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,5 @@ void UnitCell::setup(const std::string& latname_in,
const int& lmaxmax_in,
const bool& init_vel_in,
const std::string& fixed_axes_in) {}
void UnitCell::cal_nelec(double& nelec) {}
void cal_nelec(const Atom* atoms, const int& ntype, double& nelec) {}
void UnitCell::compare_atom_labels(std::string label1, std::string label2) {}
170 changes: 166 additions & 4 deletions source/module_cell/test/unitcell_test_readpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ Magnetism::~Magnetism() { delete[] this->start_magnetization; }
* possible of an element
* - CalNelec: UnitCell::cal_nelec
* - calculate the total number of valence electrons from psp files
* - CalNbands: elecstate::ElecState::cal_nbands()
* - calculate the number of bands
*/

// mock function
Expand All @@ -114,9 +116,16 @@ class UcellTest : public ::testing::Test {
pp_dir = "./support/";
PARAM.input.pseudo_rcut = 15.0;
PARAM.input.dft_functional = "default";
PARAM.input.esolver_type = "ksdft";
PARAM.input.test_pseudo_cell = true;
PARAM.input.nspin = 1;
PARAM.input.basis_type = "pw";
GlobalV::nelec = 10.0;
GlobalV::nupdown = 0.0;
PARAM.sys.two_fermi = false;
GlobalV::NBANDS = 6;
GlobalV::NLOCAL = 6;
PARAM.input.lspinorb = false;
}
void TearDown() { ofs.close(); }
};
Expand Down Expand Up @@ -256,6 +265,7 @@ TEST_F(UcellTest, CalNwfc1) {
ucell->read_cell_pseudopots(pp_dir, ofs);
EXPECT_FALSE(ucell->atoms[0].ncpp.has_so);
EXPECT_FALSE(ucell->atoms[1].ncpp.has_so);
GlobalV::NLOCAL = 3 * 9;
ucell->cal_nwfc(ofs);
EXPECT_EQ(ucell->atoms[0].iw2l[8], 2);
EXPECT_EQ(ucell->atoms[0].iw2n[8], 0);
Expand All @@ -282,7 +292,6 @@ TEST_F(UcellTest, CalNwfc1) {
EXPECT_EQ(ucell->atoms[0].nw, 9);
EXPECT_EQ(ucell->atoms[1].nw, 9);
EXPECT_EQ(ucell->nwmax, 9);
EXPECT_EQ(GlobalV::NLOCAL, 3 * 9);
// check itia2iat
EXPECT_EQ(ucell->itia2iat.getSize(), 4);
EXPECT_EQ(ucell->itia2iat(0, 0), 0);
Expand Down Expand Up @@ -322,8 +331,8 @@ TEST_F(UcellTest, CalNwfc2) {
ucell->read_cell_pseudopots(pp_dir, ofs);
EXPECT_FALSE(ucell->atoms[0].ncpp.has_so);
EXPECT_FALSE(ucell->atoms[1].ncpp.has_so);
ucell->cal_nwfc(ofs);
EXPECT_EQ(GlobalV::NLOCAL, 3 * 9 * 2);
GlobalV::NLOCAL = 3 * 9 * 2;
EXPECT_NO_THROW(ucell->cal_nwfc(ofs));
}

TEST_F(UcellDeathTest, CheckStructure) {
Expand Down Expand Up @@ -396,10 +405,163 @@ TEST_F(UcellTest, CalNelec) {
EXPECT_EQ(1, ucell->atoms[0].na);
EXPECT_EQ(2, ucell->atoms[1].na);
double nelec = 0;
ucell->cal_nelec(nelec);
cal_nelec(ucell->atoms, ucell->ntype, nelec);
EXPECT_DOUBLE_EQ(6, nelec);
}

TEST_F(UcellTest, CalNbands)
{
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsFractionElec)
{
GlobalV::nelec = 9.5;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsSOC)
{
PARAM.input.lspinorb = true;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 20);
}

TEST_F(UcellTest, CalNbandsSDFT)
{
PARAM.input.esolver_type = "sdft";
std::vector<double> nelec_spin(2, 5.0);
EXPECT_NO_THROW(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS));
}

TEST_F(UcellTest, CalNbandsLCAO)
{
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
EXPECT_NO_THROW(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS));
}

TEST_F(UcellTest, CalNbandsLCAOINPW)
{
PARAM.input.basis_type = "lcao_in_pw";
GlobalV::NLOCAL = GlobalV::NBANDS - 1;
std::vector<double> nelec_spin(2, 5.0);
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("NLOCAL < NBANDS"));
}

TEST_F(UcellTest, CalNbandsWarning1)
{
GlobalV::NBANDS = GlobalV::nelec / 2 - 1;
std::vector<double> nelec_spin(2, 5.0);
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Too few bands!"));
}

TEST_F(UcellTest, CalNbandsWarning2)
{
PARAM.input.nspin = 2;
GlobalV::nupdown = 4.0;
std::vector<double> nelec_spin(2);
nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0;
nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0;
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Too few spin up bands!"));
}

TEST_F(UcellTest, CalNbandsWarning3)
{
PARAM.input.nspin = 2;
GlobalV::nupdown = -4.0;
std::vector<double> nelec_spin(2);
nelec_spin[0] = (GlobalV::nelec + GlobalV::nupdown) / 2.0;
nelec_spin[1] = (GlobalV::nelec - GlobalV::nupdown) / 2.0;
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("Too few spin down bands!"));
}

TEST_F(UcellTest, CalNbandsSpin1)
{
PARAM.input.nspin = 1;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 15);
}

TEST_F(UcellTest, CalNbandsSpin1LCAO)
{
PARAM.input.nspin = 1;
GlobalV::NBANDS = 0;
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsSpin4)
{
PARAM.input.nspin = 4;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 30);
}

TEST_F(UcellTest, CalNbandsSpin4LCAO)
{
PARAM.input.nspin = 4;
GlobalV::NBANDS = 0;
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsSpin2)
{
PARAM.input.nspin = 2;
GlobalV::NBANDS = 0;
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 16);
}

TEST_F(UcellTest, CalNbandsSpin2LCAO)
{
PARAM.input.nspin = 2;
GlobalV::NBANDS = 0;
PARAM.input.basis_type = "lcao";
std::vector<double> nelec_spin(2, 5.0);
cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS);
EXPECT_EQ(GlobalV::NBANDS, 6);
}

TEST_F(UcellTest, CalNbandsGaussWarning)
{
GlobalV::NBANDS = 5;
std::vector<double> nelec_spin(2, 5.0);
PARAM.input.smearing_method = "gaussian";
testing::internal::CaptureStdout();
EXPECT_EXIT(cal_nbands(GlobalV::nelec, GlobalV::NLOCAL, nelec_spin, GlobalV::NBANDS), ::testing::ExitedWithCode(0), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("for smearing, num. of bands > num. of occupied bands"));
}

#ifdef __MPI
#include "mpi.h"
int main(int argc, char** argv) {
Expand Down
Loading
Loading