Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SPOSet a class template #4685

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/Configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ struct QMCTraits
{
enum
{
DIM = OHMMS_DIM,
DIM_VGL = OHMMS_DIM + 2 // Value(1) + Gradients(OHMMS_DIM) + Laplacian(1)
DIM = OHMMS_DIM
};
using QTBase = QMCTypes<OHMMS_PRECISION, DIM>;
using QTFull = QMCTypes<OHMMS_PRECISION_FULL, DIM>;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/BsplineFactory/SplineR2R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void SplineR2R<ST>::applyRotation(const ValueMatrix& rot_mat, bool use_stored_co
for (IndexType j = 0; j < OrbitalSetSize; j++)
{
const auto cur_elem = Nsplines * i + j;
FullPrecValueType newval{0.};
FullPrecValue newval{0.};
for (IndexType k = 0; k < OrbitalSetSize; k++)
{
const auto index = i * Nsplines + k;
Expand Down
6 changes: 3 additions & 3 deletions src/QMCWaveFunctions/Fermion/DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_ratioGrad(const RefVectorWithLeader

auto psiMinv_row_dev_ptr_list = DET_ENGINE::mw_getInvRow(engine_list, WorkingIndex, !Phi->isOMPoffload());

phi_vgl_v.resize(DIM_VGL, wfc_list.size(), NumOrbitals);
phi_vgl_v.resize(SPOSet::DIM_VGL, wfc_list.size(), NumOrbitals);
ratios_local.resize(wfc_list.size());
grad_new_local.resize(wfc_list.size());

Expand Down Expand Up @@ -391,7 +391,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_ratioGradWithSpin(

auto psiMinv_row_dev_ptr_list = DET_ENGINE::mw_getInvRow(engine_list, WorkingIndex, !Phi->isOMPoffload());

phi_vgl_v.resize(DIM_VGL, wfc_list.size(), NumOrbitals);
phi_vgl_v.resize(SPOSet::DIM_VGL, wfc_list.size(), NumOrbitals);
ratios_local.resize(wfc_list.size());
grad_new_local.resize(wfc_list.size());
spingrad_new_local.resize(wfc_list.size());
Expand Down Expand Up @@ -784,7 +784,7 @@ void DiracDeterminantBatched<DET_ENGINE>::mw_calcRatio(const RefVectorWithLeader

auto psiMinv_row_dev_ptr_list = DET_ENGINE::mw_getInvRow(engine_list, WorkingIndex, !Phi->isOMPoffload());

phi_vgl_v.resize(DIM_VGL, wfc_list.size(), NumOrbitals);
phi_vgl_v.resize(SPOSet::DIM_VGL, wfc_list.size(), NumOrbitals);
ratios_local.resize(wfc_list.size());
grad_new_local.resize(wfc_list.size());

Expand Down
5 changes: 4 additions & 1 deletion src/QMCWaveFunctions/Fermion/SlaterDetBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ class BackflowTransformation;
class DiracDeterminantBase;
class MultiSlaterDetTableMethod;
struct CSFData;
class SPOSet;

template<typename VALUE>
class SPOSetT;
using SPOSet = SPOSetT<QMCTraits::QTBase::ValueType>;
class SPOSetBuilder;
class SPOSetBuilderFactory;
struct ci_configuration;
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ void RotatedSPOs::evaluateDerivatives(ParticleSet& P,
void RotatedSPOs::evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<ValueType>& dlogpsi,
const QTFull::ValueType& psiCurrent,
const FullPrecValue& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/RotatedSPOs.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class RotatedSPOs : public SPOSet, public OptimizableObject
void evaluateDerivativesWF(ParticleSet& P,
const opt_variables_type& optvars,
Vector<ValueType>& dlogpsi,
const QTFull::ValueType& psiCurrent,
const FullPrecValue& psiCurrent,
const std::vector<ValueType>& Coeff,
const std::vector<size_t>& C2node_up,
const std::vector<size_t>& C2node_dn,
Expand Down
419 changes: 231 additions & 188 deletions src/QMCWaveFunctions/SPOSet.cpp

Large diffs are not rendered by default.

172 changes: 99 additions & 73 deletions src/QMCWaveFunctions/SPOSet.h

Large diffs are not rendered by default.

41 changes: 30 additions & 11 deletions src/QMCWaveFunctions/tests/FakeSPO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
namespace qmcplusplus
{

FakeSPO::FakeSPO() : SPOSet("one_FakeSPO")
template<typename VALUE>
FakeSPO<VALUE>::FakeSPO() : SPOSet("one_FakeSPO")
{
a.resize(3, 3);

Expand Down Expand Up @@ -78,11 +79,20 @@ FakeSPO::FakeSPO() : SPOSet("one_FakeSPO")
gv[3] = TinyVector<ValueType, DIM>(0.4, 0.3, 0.1);
}

std::unique_ptr<SPOSet> FakeSPO::makeClone() const { return std::make_unique<FakeSPO>(*this); }
template<typename VALUE>
std::unique_ptr<SPOSetT<VALUE>> FakeSPO<VALUE>::makeClone() const
{
return std::make_unique<FakeSPO>(*this);
}

void FakeSPO::setOrbitalSetSize(int norbs) { OrbitalSetSize = norbs; }
template<typename VALUE>
void FakeSPO<VALUE>::setOrbitalSetSize(int norbs)
{
OrbitalSetSize = norbs;
}

void FakeSPO::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
template<typename VALUE>
void FakeSPO<VALUE>::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
{
if (iat < 0)
for (int i = 0; i < psi.size(); i++)
Expand All @@ -95,7 +105,8 @@ void FakeSPO::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)
psi[i] = a2(iat, i);
}

void FakeSPO::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
template<typename VALUE>
void FakeSPO<VALUE>::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradVector& dpsi, ValueVector& d2psi)
{
if (OrbitalSetSize == 3)
{
Expand All @@ -115,12 +126,13 @@ void FakeSPO::evaluateVGL(const ParticleSet& P, int iat, ValueVector& psi, GradV
}
}

void FakeSPO::evaluate_notranspose(const ParticleSet& P,
int first,
int last,
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet)
template<typename VALUE>
void FakeSPO<VALUE>::evaluate_notranspose(const ParticleSet& P,
int first,
int last,
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet)
{
if (OrbitalSetSize == 3)
{
Expand All @@ -142,4 +154,11 @@ void FakeSPO::evaluate_notranspose(const ParticleSet& P,
}
}

#if !defined(MIXED_PRECISION)
template class FakeSPO<double>;
template class FakeSPO<std::complex<double>>;
#endif
template class FakeSPO<float>;
template class FakeSPO<std::complex<float>>;

} // namespace qmcplusplus
27 changes: 25 additions & 2 deletions src/QMCWaveFunctions/tests/FakeSPO.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,28 @@
namespace qmcplusplus
{

class FakeSPO : public SPOSet
template<typename VALUE>
class FakeSPO : public SPOSetT<VALUE>
{
public:
enum
{
DIM = OHMMS_DIM,
};
using SPOSet = SPOSetT<VALUE>;
using ValueType = typename SPOSet::ValueType;
using GradType = typename SPOSet::GradType;
using ValueVector = typename SPOSet::ValueVector;
using GradVector = typename SPOSet::GradVector;
using ValueMatrix = typename SPOSet::ValueMatrix;
using GradMatrix = typename SPOSet::GradMatrix;

Matrix<ValueType> a;
Matrix<ValueType> a2;
Vector<ValueType> v;
Matrix<ValueType> v2;

SPOSet::GradVector gv;
GradVector gv;

FakeSPO();
~FakeSPO() override {}
Expand All @@ -47,7 +60,17 @@ class FakeSPO : public SPOSet
ValueMatrix& logdet,
GradMatrix& dlogdet,
ValueMatrix& d2logdet) override;

private:
using SPOSet::OrbitalSetSize;
};

#if !defined(MIXED_PRECISION)
extern template class FakeSPO<double>;
extern template class FakeSPO<std::complex<double>>;
#endif
extern template class FakeSPO<float>;
extern template class FakeSPO<std::complex<float>>;

} // namespace qmcplusplus
#endif
12 changes: 6 additions & 6 deletions src/QMCWaveFunctions/tests/test_DiracDeterminant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ void check_matrix(Matrix<T1>& a, Matrix<T2>& b)
template<typename DET>
void test_DiracDeterminant_first(const DetMatInvertor inverter_kind)
{
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 3;
spo_init->setOrbitalSetSize(norb);
DET ddb(std::move(spo_init), 0, norb, 1, inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -159,11 +159,11 @@ TEST_CASE("DiracDeterminant_first", "[wavefunction][fermion]")
template<typename DET>
void test_DiracDeterminant_second(const DetMatInvertor inverter_kind)
{
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
DET ddb(std::move(spo_init), 0, norb, 1, inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -300,12 +300,12 @@ TEST_CASE("DiracDeterminant_second", "[wavefunction][fermion]")
template<typename DET>
void test_DiracDeterminant_delayed_update(const DetMatInvertor inverter_kind)
{
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
// maximum delay 2
DET ddc(std::move(spo_init), 0, norb, 2, inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddc.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddc.getPhi());

// occurs in call to registerData
ddc.dpsiV.resize(norb);
Expand Down
12 changes: 6 additions & 6 deletions src/QMCWaveFunctions/tests/test_DiracDeterminantBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ template<class DET_ENGINE>
void test_DiracDeterminantBatched_first()
{
using DetType = DiracDeterminantBatched<DET_ENGINE>;
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 3;
spo_init->setOrbitalSetSize(norb);
DetType ddb(std::move(spo_init), 0, norb);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -141,11 +141,11 @@ template<class DET_ENGINE>
void test_DiracDeterminantBatched_second()
{
using DetType = DiracDeterminantBatched<DET_ENGINE>;
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
DetType ddb(std::move(spo_init), 0, norb);
auto spo = dynamic_cast<FakeSPO*>(ddb.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddb.getPhi());

// occurs in call to registerData
ddb.dpsiV.resize(norb);
Expand Down Expand Up @@ -277,11 +277,11 @@ template<class DET_ENGINE>
void test_DiracDeterminantBatched_delayed_update(int delay_rank, DetMatInvertor matrix_inverter_kind)
{
using DetType = DiracDeterminantBatched<DET_ENGINE>;
auto spo_init = std::make_unique<FakeSPO>();
auto spo_init = std::make_unique<FakeSPO<ValueType>>();
const int norb = 4;
spo_init->setOrbitalSetSize(norb);
DetType ddc(std::move(spo_init), 0, norb, delay_rank, matrix_inverter_kind);
auto spo = dynamic_cast<FakeSPO*>(ddc.getPhi());
auto spo = dynamic_cast<FakeSPO<ValueType>*>(ddc.getPhi());

// occurs in call to registerData
ddc.dpsiV.resize(norb);
Expand Down
11 changes: 6 additions & 5 deletions src/QMCWaveFunctions/tests/test_RotatedSPOs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,15 +644,16 @@ TEST_CASE("RotatedSPOs construct delta matrix", "[wavefunction]")

namespace testing
{
opt_variables_type& getMyVars(SPOSet& rot) { return rot.myVars; }
template<typename Value>
opt_variables_type& getMyVars(SPOSetT<Value>& rot) { return rot.myVars; }
opt_variables_type& getMyVarsFull(RotatedSPOs& rot) { return rot.myVarsFull; }
std::vector<std::vector<QMCTraits::RealType>>& getHistoryParams(RotatedSPOs& rot) { return rot.history_params_; }
} // namespace testing

// Test using global rotation
TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
{
auto fake_spo = std::make_unique<FakeSPO>();
auto fake_spo = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo->setOrbitalSetSize(4);
RotatedSPOs rot("fake_rot", std::move(fake_spo));
int nel = 2;
Expand All @@ -673,7 +674,7 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
rot.writeVariationalParameters(hout);
}

auto fake_spo2 = std::make_unique<FakeSPO>();
auto fake_spo2 = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo2->setOrbitalSetSize(4);

RotatedSPOs rot2("fake_rot", std::move(fake_spo2));
Expand Down Expand Up @@ -704,7 +705,7 @@ TEST_CASE("RotatedSPOs read and write parameters", "[wavefunction]")
// Test using history list.
TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
{
auto fake_spo = std::make_unique<FakeSPO>();
auto fake_spo = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo->setOrbitalSetSize(4);
RotatedSPOs rot("fake_rot", std::move(fake_spo));
rot.set_use_global_rotation(false);
Expand All @@ -726,7 +727,7 @@ TEST_CASE("RotatedSPOs read and write parameters history", "[wavefunction]")
rot.writeVariationalParameters(hout);
}

auto fake_spo2 = std::make_unique<FakeSPO>();
auto fake_spo2 = std::make_unique<FakeSPO<QMCTraits::ValueType>>();
fake_spo2->setOrbitalSetSize(4);

RotatedSPOs rot2("fake_rot", std::move(fake_spo2));
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/tests/test_einset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ TEST_CASE("Einspline SPO from HDF diamond_2x1x1 5 electrons", "[wavefunction]")
std::vector<const SPOSet::ValueType*> inv_row_ptr(nw, inv_row.device_data());

SPOSet::OffloadMWVGLArray phi_vgl_v;
phi_vgl_v.resize(QMCTraits::DIM_VGL, nw, 5);
phi_vgl_v.resize(SPOSet::DIM_VGL, nw, 5);
spo->mw_evaluateVGLandDetRatioGrads(spo_list, p_list, 0, inv_row_ptr, phi_vgl_v, ratio_v, grads_v);
#if !defined(QMC_COMPLEX)
CHECK(std::real(ratio_v[0]) == Approx(0.2365307168));
Expand Down
2 changes: 1 addition & 1 deletion src/QMCWaveFunctions/tests/test_einset_NiO_a16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ TEST_CASE("Einspline SPO from HDF NiO a16 97 electrons", "[wavefunction]")
std::vector<const SPOSet::ValueType*> inv_row_ptr(nw, inv_row.device_data());

SPOSet::OffloadMWVGLArray phi_vgl_v;
phi_vgl_v.resize(QMCTraits::DIM_VGL, nw, 5);
phi_vgl_v.resize(SPOSet::DIM_VGL, nw, 5);
spo->mw_evaluateVGLandDetRatioGrads(spo_list, p_list, 0, inv_row_ptr, phi_vgl_v, ratio_v, grads_v);
phi_vgl_v.updateFrom();
#if !defined(QMC_COMPLEX)
Expand Down