From eaae12d898ea8bbab43426e0667a277126a7c0a9 Mon Sep 17 00:00:00 2001 From: William F Godoy Date: Mon, 28 Aug 2023 10:56:47 -0400 Subject: [PATCH] Specialize functions in RotatedSPOsT Fix function signature --- src/QMCWaveFunctions/RotatedSPOsT.cpp | 6 +- src/QMCWaveFunctions/RotatedSPOsT.h | 7 +- src/QMCWaveFunctions/SPOSetBuilderT.cpp | 153 ++++++++++-------------- src/QMCWaveFunctions/SPOSetT.cpp | 4 +- src/QMCWaveFunctions/SPOSetT.h | 4 +- 5 files changed, 74 insertions(+), 100 deletions(-) diff --git a/src/QMCWaveFunctions/RotatedSPOsT.cpp b/src/QMCWaveFunctions/RotatedSPOsT.cpp index 5a992ebce8..f76150ec2a 100644 --- a/src/QMCWaveFunctions/RotatedSPOsT.cpp +++ b/src/QMCWaveFunctions/RotatedSPOsT.cpp @@ -975,9 +975,9 @@ void RotatedSPOsT::evaluateDerivatives(ParticleSet& P, template void RotatedSPOsT::evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const FullRealType& psiCurrent, - const std::vector& Coeff, + Vector& dlogpsi, + const ValueType& psiCurrent, + const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn, const ValueVector& detValues_up, diff --git a/src/QMCWaveFunctions/RotatedSPOsT.h b/src/QMCWaveFunctions/RotatedSPOsT.h index 3273681455..77daf7fd92 100644 --- a/src/QMCWaveFunctions/RotatedSPOsT.h +++ b/src/QMCWaveFunctions/RotatedSPOsT.h @@ -35,6 +35,7 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObject public: using IndexType = typename SPOSetT::IndexType; using RealType = typename SPOSetT::RealType; + using ValueType = typename SPOSetT::ValueType; using FullRealType = typename SPOSetT::FullRealType; using ValueVector = typename SPOSetT::ValueVector; using ValueMatrix = typename SPOSetT::ValueMatrix; @@ -200,9 +201,9 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObject void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const FullRealType& psiCurrent, - const std::vector& Coeff, + Vector& dlogpsi, + const ValueType& psiCurrent, + const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn, const ValueVector& detValues_up, diff --git a/src/QMCWaveFunctions/SPOSetBuilderT.cpp b/src/QMCWaveFunctions/SPOSetBuilderT.cpp index c682d6a77a..3f0c6f4115 100644 --- a/src/QMCWaveFunctions/SPOSetBuilderT.cpp +++ b/src/QMCWaveFunctions/SPOSetBuilderT.cpp @@ -15,10 +15,7 @@ #include "SPOSetBuilderT.h" #include "OhmmsData/AttributeSet.h" #include - -#ifndef QMC_COMPLEX -#include "QMCWaveFunctions/RotatedSPOsT.h" -#endif +#include "QMCWaveFunctions/RotatedSPOsT.h" // only for real wavefunctions namespace qmcplusplus { @@ -45,96 +42,51 @@ std::unique_ptr> SPOSetBuilderT::createSPOSet(xmlNodePtr cur, SPOS return 0; } -template -std::unique_ptr> SPOSetBuilderT::createSPOSet(xmlNodePtr cur) + +template<> +std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cur) { std::string spo_object_name; - std::string optimize("no"); - + std::string method; OhmmsAttributeSet attrib; - attrib.add(spo_object_name, "id"); attrib.add(spo_object_name, "name"); - attrib.add(optimize, "optimize"); + attrib.add(method, "method", {"global", "history"}); attrib.put(cur); - app_summary() << std::endl; - app_summary() << " Single particle orbitals (SPO)" << std::endl; - app_summary() << " ------------------------------" << std::endl; - app_summary() << " Name: " << spo_object_name << " Type: " << type_name_ - << " Builder class name: " << ClassName << std::endl; - app_summary() << std::endl; - - if (spo_object_name.empty()) - myComm->barrier_and_abort("SPOSet object \"name\" attribute not given in the input!"); - - // read specialized sposet construction requests - // and translate them into a set of orbital indices - SPOSetInputInfo input_info(cur); - - // process general sposet construction requests - // and preserve legacy interface - std::unique_ptr> sposet; - - try - { - if (legacy && input_info.legacy_request) - sposet = createSPOSetFromXML(cur); - else - sposet = createSPOSet(cur, input_info); - } - catch (const UniformCommunicateError& ue) - { - myComm->barrier_and_abort(ue.what()); - } + std::unique_ptr> sposet; + processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) { + if (cname == "sposet") + { + sposet = createSPOSet(element); + } + }); if (!sposet) - myComm->barrier_and_abort("SPOSetBuilderT::createSPOSet sposet creation failed"); - - if (optimize == "rotation" || optimize == "yes") - { -#ifdef QMC_COMPLEX - app_error() << "Orbital optimization via rotation doesn't support complex wavefunction yet.\n"; - abort(); -#else - app_warning() << "Specifying orbital rotation via optimize tag is deprecated. Use the rotated_spo element instead" - << std::endl; - - sposet->storeParamsBeforeRotation(); - // create sposet with rotation - auto& sposet_ref = *sposet; - app_log() << " SPOSet " << sposet_ref.getName() << " is optimizable\n"; - if (!sposet_ref.isRotationSupported()) - myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet_ref.getName() + "' of type '" + - sposet_ref.getClassName() + "'."); - auto rot_spo = std::make_unique>(sposet_ref.getName(), std::move(sposet)); - xmlNodePtr tcur = cur->xmlChildrenNode; - while (tcur != NULL) + myComm->barrier_and_abort("Rotated SPO needs an SPOset"); + + if (!sposet->isRotationSupported()) + myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet->getName() + "' of type '" + + sposet->getClassName() + "'."); + + sposet->storeParamsBeforeRotation(); + auto rot_spo = std::make_unique>(spo_object_name, std::move(sposet)); + + if (method == "history") + rot_spo->set_use_global_rotation(false); + + processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) { + if (cname == "opt_vars") { - std::string cname((const char*)(tcur->name)); - if (cname == "opt_vars") - { - std::vector params; - putContent(params, tcur); - rot_spo->setRotationParameters(params); - } - tcur = tcur->next; + std::vector params; + putContent(params, element); + rot_spo->setRotationParameters(params); } - sposet = std::move(rot_spo); -#endif - } - - if (sposet->getName().empty()) - app_warning() << "SPOSet object doesn't have a name." << std::endl; - if (!spo_object_name.empty() && sposet->getName() != spo_object_name) - app_warning() << "SPOSet object name mismatched! input name: " << spo_object_name - << " object name: " << sposet->getName() << std::endl; - - sposet->checkObject(); - return sposet; + }); + return rot_spo; } -template -std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cur) +template<> +std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cur) { std::string spo_object_name; std::string method; @@ -143,12 +95,7 @@ std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cu attrib.add(method, "method", {"global", "history"}); attrib.put(cur); - -#ifdef QMC_COMPLEX - myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet."); - return nullptr; -#else - std::unique_ptr> sposet; + std::unique_ptr> sposet; processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) { if (cname == "sposet") { @@ -164,7 +111,7 @@ std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cu sposet->getClassName() + "'."); sposet->storeParamsBeforeRotation(); - auto rot_spo = std::make_unique>(spo_object_name, std::move(sposet)); + auto rot_spo = std::make_unique>(spo_object_name, std::move(sposet)); if (method == "history") rot_spo->set_use_global_rotation(false); @@ -178,8 +125,34 @@ std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cu } }); return rot_spo; -#endif } + +template<> +std::unique_ptr>> SPOSetBuilderT>::createRotatedSPOSet(xmlNodePtr cur) +{ + std::string spo_object_name; + std::string method; + OhmmsAttributeSet attrib; + attrib.add(spo_object_name, "name"); + attrib.add(method, "method", {"global", "history"}); + attrib.put(cur); + myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet."); + return nullptr; +} + +template<> +std::unique_ptr>> SPOSetBuilderT>::createRotatedSPOSet(xmlNodePtr cur) +{ + std::string spo_object_name; + std::string method; + OhmmsAttributeSet attrib; + attrib.add(spo_object_name, "name"); + attrib.add(method, "method", {"global", "history"}); + attrib.put(cur); + myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet."); + return nullptr; +} + template class SPOSetBuilderT; template class SPOSetBuilderT; template class SPOSetBuilderT>; diff --git a/src/QMCWaveFunctions/SPOSetT.cpp b/src/QMCWaveFunctions/SPOSetT.cpp index 34c76bad82..c20bda6513 100644 --- a/src/QMCWaveFunctions/SPOSetT.cpp +++ b/src/QMCWaveFunctions/SPOSetT.cpp @@ -359,8 +359,8 @@ void SPOSetT::evaluateDerivatives(ParticleSet& P, template void SPOSetT::evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const typename QTFull::ValueType& psiCurrent, + Vector& dlogpsi, + const ValueType& psiCurrent, const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn, diff --git a/src/QMCWaveFunctions/SPOSetT.h b/src/QMCWaveFunctions/SPOSetT.h index ddc14c6593..6e12c3e929 100644 --- a/src/QMCWaveFunctions/SPOSetT.h +++ b/src/QMCWaveFunctions/SPOSetT.h @@ -179,8 +179,8 @@ class SPOSetT : public QMCTraits */ virtual void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const typename QTFull::ValueType& psiCurrent, + Vector& dlogpsi, + const ValueType& psiCurrent, const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn,