Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable complex coefficients for SelfHealingOverlap estimator #5291

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 70 additions & 26 deletions src/Estimators/SelfHealingOverlap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "SelfHealingOverlap.h"
#include "TrialWaveFunction.h"
#include "QMCWaveFunctions/Fermion/MultiSlaterDetTableMethod.h"
#include "QMCWaveFunctions/Fermion/SlaterDet.h"

#include <iostream>
#include <numeric>
Expand All @@ -19,19 +20,47 @@
namespace qmcplusplus
{
SelfHealingOverlap::SelfHealingOverlap(SelfHealingOverlapInput&& inp_, const TrialWaveFunction& wfn, DataLocality dl)
: OperatorEstBase(dl), input_(std::move(inp_))
: OperatorEstBase(dl),
input_(std::move(inp_)),
wf_type(no_wf),
use_param_deriv(input_.input_section_.get<bool>("param_deriv"))
{
//my_name_ = input_.get_name();

auto& inp = this->input_.input_section_;

auto msd_refvec = wfn.findMSD();
if (msd_refvec.size() != 1)
auto sd_refvec = wfn.findSD();

auto nsd = sd_refvec.size();
auto nmsd = msd_refvec.size();

size_t nparams;
if (nmsd == 1 and nsd == 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{ // multi-slater-det wavefunction
wf_type = msd_wf;
const MultiSlaterDetTableMethod& msd = msd_refvec[0];
if (!use_param_deriv)
nparams = msd.getLinearExpansionCoefs().size();
else
{
throw std::runtime_error("SelfHealingOverlap: use_param_deriv implementation incomplete, needs access to param "

Check warning on line 43 in src/Estimators/SelfHealingOverlap.cpp

View check run for this annotation

Codecov / codecov/patch

src/Estimators/SelfHealingOverlap.cpp#L43

Added line #L43 was not covered by tests
"count from wavefunction component myVars");
}
if (nparams == 0)
throw std::runtime_error("SelfHealingOverlap: multidet wavefunction has no parameters.");
}
else if (nmsd == 0 and nsd == 1)
{ // slater-det wavefunction
throw std::runtime_error("SelfHealingOverlap: slaterdet wavefunction implementation incomplete");
}
else
{
throw std::runtime_error(
"SelfHealingOverlap requires one and only one multi slater determinant component in the trial wavefunction.");
"SelfHealingOverlap requires a single slater or multi-slater determinant component in the trial wavefunction.");
}

const MultiSlaterDetTableMethod& msd = msd_refvec[0];
const size_t data_size = msd.getLinearExpansionCoefs().size();
#ifndef QMC_COMPLEX
const size_t data_size = nparams;
#else
const size_t data_size = 2 * nparams;

Check warning on line 62 in src/Estimators/SelfHealingOverlap.cpp

View check run for this annotation

Codecov / codecov/patch

src/Estimators/SelfHealingOverlap.cpp#L62

Added line #L62 was not covered by tests
#endif
data_.resize(data_size, 0.0);
}

Expand Down Expand Up @@ -78,39 +107,54 @@
RealType weight = walker.Weight;
auto& wcs = psi.getOrbitals();

// separate jastrow and fermi wavefunction components
// find jastrow wavefunction components
std::vector<WaveFunctionComponent*> wcs_jastrow;
std::vector<WaveFunctionComponent*> wcs_fermi;
for (auto& wc : wcs)
if (wc->isFermionic())
wcs_fermi.push_back(wc.get());
else
if (!wc->isFermionic())
wcs_jastrow.push_back(wc.get());

// fermionic must have only one component, and must be multideterminant
assert(wcs_fermi.size() == 1);
WaveFunctionComponent& wf = *wcs_fermi[0];
if (!wf.isMultiDet())
throw std::runtime_error("SelfHealingOverlap estimator requires use of multideterminant wavefunction");
auto msd_refvec = psi.findMSD();
MultiSlaterDetTableMethod& msd = msd_refvec[0];

// collect parameter derivatives: (dpsi/dc_i)/psi
msd.calcIndividualDetRatios(det_ratios);
if (wf_type == msd_wf)
{
auto msd_refvec = psi.findMSD();
MultiSlaterDetTableMethod& msd = msd_refvec[0];
// collect parameter derivatives: (dpsi/dc_i)/psi
if (!use_param_deriv)
msd.calcIndividualDetRatios(det_ratios);
else
{
throw std::runtime_error("SelfHealingOverlap: use_param_deriv implementation incomplete, needs call to "

Check warning on line 125 in src/Estimators/SelfHealingOverlap.cpp

View check run for this annotation

Codecov / codecov/patch

src/Estimators/SelfHealingOverlap.cpp#L125

Added line #L125 was not covered by tests
"msd.evaluateDerivatives with correct myVars");
}
}
else if (wf_type == sd_rot_wf)
{
throw std::runtime_error("SelfHealingOverlap: slaterdet wavefunction implementation incomplete");
auto sd_refvec = psi.findSD();
}
else
throw std::runtime_error("SelfHealingOverlap: impossible branch reached, contact the developers");

// collect jastrow prefactor
WaveFunctionComponent::LogValue Jval = 0.0;
for (auto& wc : wcs_jastrow)
Jval += wc->get_log_value();
auto Jprefactor = std::real(std::exp(-2. * Jval));
auto Jprefactor = std::exp(-2. * Jval);

// accumulate weight (required by all estimators, otherwise inf results)
walkers_weight_ += weight;

// accumulate data
assert(det_ratios.size() == data_.size());
for (int ic = 0; ic < det_ratios.size(); ++ic)
data_[ic] += weight * Jprefactor * std::real(det_ratios[ic]); // only real supported for now
{
#ifndef QMC_COMPLEX
data_[ic] += weight * std::real(Jprefactor) * det_ratios[ic];
#else
auto value = weight * Jprefactor * std::conj(det_ratios[ic]);
data_[2 * ic] += std::real(value);
data_[2 * ic + 1] += std::imag(value);

Check warning on line 155 in src/Estimators/SelfHealingOverlap.cpp

View check run for this annotation

Codecov / codecov/patch

src/Estimators/SelfHealingOverlap.cpp#L153-L155

Added lines #L153 - L155 were not covered by tests
#endif
}
}
}

Expand Down
15 changes: 15 additions & 0 deletions src/Estimators/SelfHealingOverlap.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,28 @@ class SelfHealingOverlap : public OperatorEstBase
using ValueType = QMCTraits::ValueType;
using PosType = QMCTraits::PosType;


enum wf_types
{
msd_wf = 0,
sd_rot_wf,
no_wf
};

//data members set only during construction
const SelfHealingOverlapInput input_;

/** @ingroup SelfHealingOverlap mutable data members
*/
Vector<ValueType> det_ratios;

/// wavefunction type
wf_types wf_type;

/// use direct parameter derivative for MSD or not
const bool use_param_deriv;


public:
/** Constructor for SelfHealingOverlapInput
*/
Expand Down
9 changes: 6 additions & 3 deletions src/Estimators/SelfHealingOverlapInput.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SelfHealingOverlapInput
{
public:
using Consumer = SelfHealingOverlap;
using Real = QMCTraits::RealType;
using Real = QMCTraits::RealType;

class SelfHealingOverlapInputSection : public InputSection
{
Expand All @@ -32,9 +32,12 @@ class SelfHealingOverlapInput
SelfHealingOverlapInputSection()
{
section_name = "SelfHealingOverlap";
attributes = {"type", "name"};
attributes = {"type", "name", "param_deriv"};
strings = {"type", "name"};
default_values = {{"type", std::string("sh_overlap")},{"name", std::string("sh_overlap")}};
bools = {"param_deriv"};
default_values = {{"type", std::string("sh_overlap")},
{"name", std::string("sh_overlap")},
{"param_deriv", false}};
}
// clang-format: on
};
Expand Down
10 changes: 10 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "Concurrency/Info.hpp"
#include "type_traits/ConvertToReal.h"
#include "NaNguard.h"
#include "Fermion/SlaterDet.h"
#include "Fermion/MultiSlaterDetTableMethod.h"

namespace qmcplusplus
Expand Down Expand Up @@ -107,6 +108,15 @@ const SPOSet& TrialWaveFunction::getSPOSet(const std::string& name) const
return *spoit->second;
}

RefVector<SlaterDet> TrialWaveFunction::findSD() const
{
RefVector<SlaterDet> refs;
for (auto& component : Z)
if (auto* comp_ptr = dynamic_cast<SlaterDet*>(component.get()); comp_ptr)
Copy link
Contributor

@PDoakORNL PDoakORNL Jan 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

findSD is seems too brief, this is now part of the public API and it would be nice if it was recognizable as what it is from that. Imagine I want to search the codebase for SlaterDet definitions

auto my_det = twf.findSD()

SD is not a character combination of much specificity.

This smells a bit, dynamic_casts aren't zero cost

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry about that. I'd like to addressed it separately.
findSD/findMSD are too specific. I'd like to change them as mentioned #5291 (comment)

refs.push_back(*comp_ptr);
return refs;
}

RefVector<MultiSlaterDetTableMethod> TrialWaveFunction::findMSD() const
{
RefVector<MultiSlaterDetTableMethod> refs;
Expand Down
4 changes: 4 additions & 0 deletions src/QMCWaveFunctions/TrialWaveFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

namespace qmcplusplus
{
class SlaterDet;
class MultiSlaterDetTableMethod;

/** @ingroup MBWfs
Expand Down Expand Up @@ -539,6 +540,9 @@ class TrialWaveFunction
/// spomap_ reference accessor
const SPOMap& getSPOMap() const { return *spomap_; }

/// find SD WFCs if exist
RefVector<SlaterDet> findSD() const;
ye-luo marked this conversation as resolved.
Show resolved Hide resolved

/// find MSD WFCs if exist
RefVector<MultiSlaterDetTableMethod> findMSD() const;

Expand Down
Loading