diff --git a/src/QMCWaveFunctions/RotatedSPOsT.cpp b/src/QMCWaveFunctions/RotatedSPOsT.cpp index 5a992ebce89..f76150ec2aa 100644 --- a/src/QMCWaveFunctions/RotatedSPOsT.cpp +++ b/src/QMCWaveFunctions/RotatedSPOsT.cpp @@ -975,9 +975,9 @@ void RotatedSPOsT::evaluateDerivatives(ParticleSet& P, template void RotatedSPOsT::evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const FullRealType& psiCurrent, - const std::vector& Coeff, + Vector& dlogpsi, + const ValueType& psiCurrent, + const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn, const ValueVector& detValues_up, diff --git a/src/QMCWaveFunctions/RotatedSPOsT.h b/src/QMCWaveFunctions/RotatedSPOsT.h index 3273681455e..77daf7fd923 100644 --- a/src/QMCWaveFunctions/RotatedSPOsT.h +++ b/src/QMCWaveFunctions/RotatedSPOsT.h @@ -35,6 +35,7 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObject public: using IndexType = typename SPOSetT::IndexType; using RealType = typename SPOSetT::RealType; + using ValueType = typename SPOSetT::ValueType; using FullRealType = typename SPOSetT::FullRealType; using ValueVector = typename SPOSetT::ValueVector; using ValueMatrix = typename SPOSetT::ValueMatrix; @@ -200,9 +201,9 @@ class RotatedSPOsT : public SPOSetT, public OptimizableObject void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const FullRealType& psiCurrent, - const std::vector& Coeff, + Vector& dlogpsi, + const ValueType& psiCurrent, + const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn, const ValueVector& detValues_up, diff --git a/src/QMCWaveFunctions/SPOSetBuilderT.cpp b/src/QMCWaveFunctions/SPOSetBuilderT.cpp index c682d6a77aa..80f8ee4b850 100644 --- a/src/QMCWaveFunctions/SPOSetBuilderT.cpp +++ b/src/QMCWaveFunctions/SPOSetBuilderT.cpp @@ -15,10 +15,7 @@ #include "SPOSetBuilderT.h" #include "OhmmsData/AttributeSet.h" #include - -#ifndef QMC_COMPLEX -#include "QMCWaveFunctions/RotatedSPOsT.h" -#endif +#include "QMCWaveFunctions/RotatedSPOsT.h" // only for real wavefunctions namespace qmcplusplus { @@ -133,8 +130,8 @@ std::unique_ptr> SPOSetBuilderT::createSPOSet(xmlNodePtr cur) return sposet; } -template -std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cur) +template<> +std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cur) { std::string spo_object_name; std::string method; @@ -143,12 +140,49 @@ std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cu attrib.add(method, "method", {"global", "history"}); attrib.put(cur); + std::unique_ptr> sposet; + processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) { + if (cname == "sposet") + { + sposet = createSPOSet(element); + } + }); + + if (!sposet) + myComm->barrier_and_abort("Rotated SPO needs an SPOset"); + + if (!sposet->isRotationSupported()) + myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet->getName() + "' of type '" + + sposet->getClassName() + "'."); + + sposet->storeParamsBeforeRotation(); + auto rot_spo = std::make_unique>(spo_object_name, std::move(sposet)); + + if (method == "history") + rot_spo->set_use_global_rotation(false); -#ifdef QMC_COMPLEX - myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet."); - return nullptr; -#else - std::unique_ptr> sposet; + processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) { + if (cname == "opt_vars") + { + std::vector params; + putContent(params, element); + rot_spo->setRotationParameters(params); + } + }); + return rot_spo; +} + +template<> +std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cur) +{ + std::string spo_object_name; + std::string method; + OhmmsAttributeSet attrib; + attrib.add(spo_object_name, "name"); + attrib.add(method, "method", {"global", "history"}); + attrib.put(cur); + + std::unique_ptr> sposet; processChildren(cur, [&](const std::string& cname, const xmlNodePtr element) { if (cname == "sposet") { @@ -164,7 +198,7 @@ std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cu sposet->getClassName() + "'."); sposet->storeParamsBeforeRotation(); - auto rot_spo = std::make_unique>(spo_object_name, std::move(sposet)); + auto rot_spo = std::make_unique>(spo_object_name, std::move(sposet)); if (method == "history") rot_spo->set_use_global_rotation(false); @@ -178,8 +212,34 @@ std::unique_ptr> SPOSetBuilderT::createRotatedSPOSet(xmlNodePtr cu } }); return rot_spo; -#endif } + +template<> +std::unique_ptr>> SPOSetBuilderT>::createRotatedSPOSet(xmlNodePtr cur) +{ + std::string spo_object_name; + std::string method; + OhmmsAttributeSet attrib; + attrib.add(spo_object_name, "name"); + attrib.add(method, "method", {"global", "history"}); + attrib.put(cur); + myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet."); + return nullptr; +} + +template<> +std::unique_ptr>> SPOSetBuilderT>::createRotatedSPOSet(xmlNodePtr cur) +{ + std::string spo_object_name; + std::string method; + OhmmsAttributeSet attrib; + attrib.add(spo_object_name, "name"); + attrib.add(method, "method", {"global", "history"}); + attrib.put(cur); + myComm->barrier_and_abort("Orbital optimization via rotation doesn't support complex wavefunctions yet."); + return nullptr; +} + template class SPOSetBuilderT; template class SPOSetBuilderT; template class SPOSetBuilderT>; diff --git a/src/QMCWaveFunctions/SPOSetT.cpp b/src/QMCWaveFunctions/SPOSetT.cpp index 34c76bad821..c20bda6513e 100644 --- a/src/QMCWaveFunctions/SPOSetT.cpp +++ b/src/QMCWaveFunctions/SPOSetT.cpp @@ -359,8 +359,8 @@ void SPOSetT::evaluateDerivatives(ParticleSet& P, template void SPOSetT::evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const typename QTFull::ValueType& psiCurrent, + Vector& dlogpsi, + const ValueType& psiCurrent, const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn, diff --git a/src/QMCWaveFunctions/SPOSetT.h b/src/QMCWaveFunctions/SPOSetT.h index ddc14c65937..6e12c3e9299 100644 --- a/src/QMCWaveFunctions/SPOSetT.h +++ b/src/QMCWaveFunctions/SPOSetT.h @@ -179,8 +179,8 @@ class SPOSetT : public QMCTraits */ virtual void evaluateDerivativesWF(ParticleSet& P, const opt_variables_type& optvars, - Vector& dlogpsi, - const typename QTFull::ValueType& psiCurrent, + Vector& dlogpsi, + const ValueType& psiCurrent, const std::vector& Coeff, const std::vector& C2node_up, const std::vector& C2node_dn,