Skip to content

Commit

Permalink
Specialize functions in RotatedSPOsT
Browse files Browse the repository at this point in the history
Fix function signature
  • Loading branch information
williamfgc committed Aug 28, 2023
1 parent acb8862 commit eaae12d
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 100 deletions.
6 changes: 3 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOsT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,9 @@ void RotatedSPOsT<T>::evaluateDerivatives(ParticleSet& P,
template<typename T>
void RotatedSPOsT<T>::evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const FullRealType& psiCurrent,
const std::vector<T>& Coeff,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
const ValueVector& detValues_up,
Expand Down
7 changes: 4 additions & 3 deletions src/QMCWaveFunctions/RotatedSPOsT.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObject
public:
using IndexType = typename SPOSetT<T>::IndexType;
using RealType = typename SPOSetT<T>::RealType;
using ValueType = typename SPOSetT<T>::ValueType;
using FullRealType = typename SPOSetT<T>::FullRealType;
using ValueVector = typename SPOSetT<T>::ValueVector;
using ValueMatrix = typename SPOSetT<T>::ValueMatrix;
Expand Down Expand Up @@ -200,9 +201,9 @@ class RotatedSPOsT : public SPOSetT<T>, public OptimizableObject

void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const FullRealType& psiCurrent,
const std::vector<T>& Coeff,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
const ValueVector& detValues_up,
Expand Down
153 changes: 63 additions & 90 deletions src/QMCWaveFunctions/SPOSetBuilderT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
#include "SPOSetBuilderT.h"
#include "OhmmsData/AttributeSet.h"
#include <Message/UniformCommunicateError.h>

#ifndef QMC_COMPLEX
#include "QMCWaveFunctions/RotatedSPOsT.h"
#endif
#include "QMCWaveFunctions/RotatedSPOsT.h" // only for real wavefunctions

namespace qmcplusplus
{
Expand All @@ -45,96 +42,51 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createSPOSet(xmlNodePtr cur, SPOS
return 0;
}

template<typename T>
std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createSPOSet(xmlNodePtr cur)

template<>
std::unique_ptr<SPOSetT<float>> SPOSetBuilderT<float>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string optimize("no");

std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "id");
attrib.add(spo_object_name, "name");
attrib.add(optimize, "optimize");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);

app_summary() << std::endl;
app_summary() << " Single particle orbitals (SPO)" << std::endl;
app_summary() << " ------------------------------" << std::endl;
app_summary() << " Name: " << spo_object_name << " Type: " << type_name_
<< " Builder class name: " << ClassName << std::endl;
app_summary() << std::endl;

if (spo_object_name.empty())
myComm->barrier_and_abort("SPOSet object \"name\" attribute not given in the input!");

// read specialized sposet construction requests
// and translate them into a set of orbital indices
SPOSetInputInfo input_info(cur);

// process general sposet construction requests
// and preserve legacy interface
std::unique_ptr<SPOSetT<T>> sposet;

try
{
if (legacy && input_info.legacy_request)
sposet = createSPOSetFromXML(cur);
else
sposet = createSPOSet(cur, input_info);
}
catch (const UniformCommunicateError& ue)
{
myComm->barrier_and_abort(ue.what());
}
std::unique_ptr<SPOSetT<float>> sposet;
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "sposet")
{
sposet = createSPOSet(element);
}
});

if (!sposet)
myComm->barrier_and_abort("SPOSetBuilderT::createSPOSet sposet creation failed");

if (optimize == "rotation" || optimize == "yes")
{
#ifdef QMC_COMPLEX
app_error() << "Orbital optimization via rotation doesn't support complex wavefunction yet.\n";
abort();
#else
app_warning() << "Specifying orbital rotation via optimize tag is deprecated. Use the rotated_spo element instead"
<< std::endl;

sposet->storeParamsBeforeRotation();
// create sposet with rotation
auto& sposet_ref = *sposet;
app_log() << " SPOSet " << sposet_ref.getName() << " is optimizable\n";
if (!sposet_ref.isRotationSupported())
myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet_ref.getName() + "' of type '" +
sposet_ref.getClassName() + "'.");
auto rot_spo = std::make_unique<RotatedSPOsT<T>>(sposet_ref.getName(), std::move(sposet));
xmlNodePtr tcur = cur->xmlChildrenNode;
while (tcur != NULL)
myComm->barrier_and_abort("Rotated SPO needs an SPOset");

if (!sposet->isRotationSupported())
myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet->getName() + "' of type '" +
sposet->getClassName() + "'.");

sposet->storeParamsBeforeRotation();
auto rot_spo = std::make_unique<RotatedSPOsT<float>>(spo_object_name, std::move(sposet));

if (method == "history")
rot_spo->set_use_global_rotation(false);

processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "opt_vars")
{
std::string cname((const char*)(tcur->name));
if (cname == "opt_vars")
{
std::vector<RealType> params;
putContent(params, tcur);
rot_spo->setRotationParameters(params);
}
tcur = tcur->next;
std::vector<RealType> params;
putContent(params, element);
rot_spo->setRotationParameters(params);
}
sposet = std::move(rot_spo);
#endif
}

if (sposet->getName().empty())
app_warning() << "SPOSet object doesn't have a name." << std::endl;
if (!spo_object_name.empty() && sposet->getName() != spo_object_name)
app_warning() << "SPOSet object name mismatched! input name: " << spo_object_name
<< " object name: " << sposet->getName() << std::endl;

sposet->checkObject();
return sposet;
});
return rot_spo;
}

template<typename T>
std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cur)
template<>
std::unique_ptr<SPOSetT<double>> SPOSetBuilderT<double>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
Expand All @@ -143,12 +95,7 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);


#ifdef QMC_COMPLEX
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
#else
std::unique_ptr<SPOSetT<T>> sposet;
std::unique_ptr<SPOSetT<double>> sposet;
processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) {
if (cname == "sposet")
{
Expand All @@ -164,7 +111,7 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
sposet->getClassName() + "'.");

sposet->storeParamsBeforeRotation();
auto rot_spo = std::make_unique<RotatedSPOsT<T>>(spo_object_name, std::move(sposet));
auto rot_spo = std::make_unique<RotatedSPOsT<double>>(spo_object_name, std::move(sposet));

if (method == "history")
rot_spo->set_use_global_rotation(false);
Expand All @@ -178,8 +125,34 @@ std::unique_ptr<SPOSetT<T>> SPOSetBuilderT<T>::createRotatedSPOSet(xmlNodePtr cu
}
});
return rot_spo;
#endif
}

template<>
std::unique_ptr<SPOSetT<std::complex<float>>> SPOSetBuilderT<std::complex<float>>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
}

template<>
std::unique_ptr<SPOSetT<std::complex<double>>> SPOSetBuilderT<std::complex<double>>::createRotatedSPOSet(xmlNodePtr cur)
{
std::string spo_object_name;
std::string method;
OhmmsAttributeSet attrib;
attrib.add(spo_object_name, "name");
attrib.add(method, "method", {"global", "history"});
attrib.put(cur);
myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet.");
return nullptr;
}

template class SPOSetBuilderT<double>;
template class SPOSetBuilderT<float>;
template class SPOSetBuilderT<std::complex<double>>;
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/SPOSetT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ void SPOSetT<T>::evaluateDerivatives(ParticleSet& P,
template<class T>
void SPOSetT<T>::evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const typename QTFull::ValueType& psiCurrent,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<T>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down
4 changes: 2 additions & 2 deletions src/QMCWaveFunctions/SPOSetT.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ class SPOSetT : public QMCTraits
*/
virtual void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<T>& dlogpsi,
const typename QTFull::ValueType& psiCurrent,
Vector<ValueType>& dlogpsi,
const ValueType& psiCurrent,
const std::vector<T>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down

0 comments on commit eaae12d

Please sign in to comment.